From 407c8d47028e9861a424e73372dc1ad9b2499f3a Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Sun, 14 Jul 2024 11:23:47 +0800 Subject: [PATCH] (improvement)(Headless) Refactor ChatLayerService and SemanticLayerService (#1404) Co-authored-by: lxwcodemonkey --- .../service/impl/ChatQueryServiceImpl.java | 368 ++++++++++++++- .../headless/api/pojo/DataSetSchema.java | 14 +- .../headless/api/pojo/SemanticSchema.java | 8 - .../api/pojo/request/DimensionValueReq.java | 9 + .../api/pojo/request/QueryDimValueReq.java | 16 - .../headless/chat/parser/QueryTypeParser.java | 6 +- .../chat/query/BaseSemanticQuery.java | 10 +- .../headless/chat/query/SemanticQuery.java | 4 +- .../chat/query/llm/s2sql/LLMSqlQuery.java | 6 +- .../chat/query/rule/RuleSemanticQuery.java | 14 +- .../facade/service/ChatLayerService.java | 7 - .../facade/service/SemanticLayerService.java | 6 +- .../service/impl/S2ChatLayerService.java | 434 +----------------- .../service/impl/S2SemanticLayerService.java | 92 +++- .../server/web/rest/DimensionController.java | 8 +- .../headless/QueryDimensionTest.java | 8 +- 16 files changed, 496 insertions(+), 514 deletions(-) delete mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryDimValueReq.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index ba8f0c915..dc4722441 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -13,29 +13,70 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor; import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor; import com.tencent.supersonic.chat.server.service.AgentService; -import com.tencent.supersonic.chat.server.service.ChatContextService; import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.util.ComponentFactory; import com.tencent.supersonic.chat.server.util.QueryReqConverter; +import com.tencent.supersonic.common.jsqlparser.FieldExpression; +import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; +import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; +import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; +import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.BeanMapper; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; +import com.tencent.supersonic.headless.api.pojo.EntityInfo; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; +import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.SearchResult; +import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; +import com.tencent.supersonic.headless.chat.query.QueryManager; +import com.tencent.supersonic.headless.chat.query.SemanticQuery; +import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import com.tencent.supersonic.headless.server.facade.service.RetrieveService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.StringValue; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.InExpression; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList; +import net.sf.jsqlparser.schema.Column; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; @Slf4j @@ -47,13 +88,11 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Autowired private ChatLayerService chatLayerService; @Autowired + private SemanticLayerService semanticLayerService; + @Autowired private RetrieveService retrieveService; @Autowired private AgentService agentService; - @Autowired - private SemanticLayerService semanticLayerService; - @Autowired - private ChatContextService chatContextService; private List chatQueryParsers = ComponentFactory.getChatParsers(); private List chatQueryExecutors = ComponentFactory.getChatExecutors(); @@ -160,23 +199,328 @@ public class ChatQueryServiceImpl implements ChatQueryService { return executeContext; } + //mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000", + //"style='流行'"->"style in ['流行','爱国']" @Override public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception { Integer parseId = chatQueryDataReq.getParseId(); SemanticParseInfo parseInfo = chatManageService.getParseInfo( chatQueryDataReq.getQueryId(), parseId); - QueryDataReq queryData = new QueryDataReq(); - BeanMapper.mapper(chatQueryDataReq, queryData); - queryData.setParseInfo(parseInfo); - return chatLayerService.executeDirectQuery(queryData, user); + parseInfo = mergeSemanticParseInfo(parseInfo, chatQueryDataReq); + DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); + + SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); + semanticQuery.setParseInfo(parseInfo); + + List fields = new ArrayList<>(); + if (Objects.nonNull(parseInfo.getSqlInfo()) + && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { + String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); + fields = SqlSelectHelper.getAllFields(correctorSql); + } + if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) + && checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { + //replace metrics + log.info("llm begin replace metrics!"); + SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next(); + replaceMetrics(parseInfo, metricToReplace); + } else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { + log.info("llm begin revise filters!"); + String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo); + parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); + semanticQuery.setParseInfo(parseInfo); + SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); + SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); + parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); + } else { + log.info("rule begin replace metrics and revise filters!"); + //remove unvalid filters + validFilter(semanticQuery.getParseInfo().getDimensionFilters()); + validFilter(semanticQuery.getParseInfo().getMetricFilters()); + //init s2sql + semanticQuery.initS2Sql(dataSetSchema, user); + } + SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); + QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user); + queryResult.setChatContext(semanticQuery.getParseInfo()); + SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); + EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user); + queryResult.setEntityInfo(entityInfo); + return queryResult; + } + + private boolean checkMetricReplace(List oriFields, Set metrics) { + if (CollectionUtils.isEmpty(oriFields)) { + return false; + } + if (CollectionUtils.isEmpty(metrics)) { + return false; + } + List metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()); + return !oriFields.containsAll(metricNames); + } + + private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) { + Map> filedNameToValueMap = new HashMap<>(); + Map> havingFiledNameToValueMap = new HashMap<>(); + + String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); + log.info("correctorSql before replacing:{}", correctorSql); + // get where filter and having filter + List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql); + List havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql); + List addWhereConditions = new ArrayList<>(); + List addHavingConditions = new ArrayList<>(); + Set removeWhereFieldNames = new HashSet<>(); + Set removeHavingFieldNames = new HashSet<>(); + // replace where filter + updateFilters(whereExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); + updateDateInfo(queryData, parseInfo, filedNameToValueMap, + whereExpressionList, addWhereConditions, removeWhereFieldNames); + correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); + correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); + // replace having filter + updateFilters(havingExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); + correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); + correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); + + correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions); + correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions); + log.info("correctorSql after replacing:{}", correctorSql); + return correctorSql; + } + + private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { + List oriMetrics = parseInfo.getMetrics().stream() + .map(SchemaElement::getName).collect(Collectors.toList()); + String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); + log.info("before replaceMetrics:{}", correctorSql); + log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); + Map> fieldMap = new HashMap<>(); + if (!CollectionUtils.isEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) { + fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg())); + correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap); + } + log.info("after replaceMetrics:{}", correctorSql); + parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); + } + + private QueryResult doExecution(SemanticQueryReq semanticQueryReq, + SemanticParseInfo parseInfo, User user) throws Exception { + SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user); + QueryResult queryResult = new QueryResult(); + if (queryResp != null) { + queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); + } + + String sql = queryResp == null ? null : queryResp.getSql(); + List> resultList = queryResp == null ? new ArrayList<>() + : queryResp.getResultList(); + List columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns(); + queryResult.setQuerySql(sql); + queryResult.setQueryResults(resultList); + queryResult.setQueryColumns(columns); + queryResult.setQueryMode(parseInfo.getQueryMode()); + queryResult.setQueryState(QueryState.SUCCESS); + + return queryResult; + } + + private void updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, + Map> filedNameToValueMap, + List fieldExpressionList, + List addConditions, + Set removeFieldNames) { + if (Objects.isNull(queryData.getDateInfo())) { + return; + } + if (queryData.getDateInfo().getUnit() > 1) { + queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); + queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1)); + } + // startDate equals to endDate + for (FieldExpression fieldExpression : fieldExpressionList) { + if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) { + // first remove,then add + removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); + GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); + addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); + MinorThanEquals minorThanEquals = new MinorThanEquals(); + addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions); + break; + } + } + for (FieldExpression fieldExpression : fieldExpressionList) { + for (QueryFilter queryFilter : queryData.getDimensionFilters()) { + if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE) + && FilterOperatorEnum.LIKE.getValue().toLowerCase().equals( + fieldExpression.getOperator().toLowerCase())) { + Map replaceMap = new HashMap<>(); + String preValue = fieldExpression.getFieldValue().toString(); + String curValue = queryFilter.getValue().toString(); + if (preValue.startsWith("%")) { + curValue = "%" + curValue; + } + if (preValue.endsWith("%")) { + curValue = curValue + "%"; + } + replaceMap.put(preValue, curValue); + filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap); + break; + } + } + } + parseInfo.setDateInfo(queryData.getDateInfo()); + } + + private void addTimeFilters(String date, + T comparisonExpression, + List addConditions) { + Column column = new Column(TimeDimensionEnum.DAY.getChName()); + StringValue stringValue = new StringValue(date); + comparisonExpression.setLeftExpression(column); + comparisonExpression.setRightExpression(stringValue); + addConditions.add(comparisonExpression); + } + + private void updateFilters(List fieldExpressionList, + Set metricFilters, + Set contextMetricFilters, + List addConditions, + Set removeFieldNames) { + if (org.apache.commons.collections.CollectionUtils.isEmpty(metricFilters)) { + return; + } + for (QueryFilter dslQueryFilter : metricFilters) { + for (FieldExpression fieldExpression : fieldExpressionList) { + if (fieldExpression.getFieldName() != null + && fieldExpression.getFieldName().contains(dslQueryFilter.getName())) { + removeFieldNames.add(dslQueryFilter.getName()); + if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { + EqualsTo equalsTo = new EqualsTo(); + addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions); + } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) { + GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); + addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions); + } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) { + GreaterThan greaterThan = new GreaterThan(); + addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions); + } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) { + MinorThanEquals minorThanEquals = new MinorThanEquals(); + addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions); + } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) { + MinorThan minorThan = new MinorThan(); + addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions); + } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) { + InExpression inExpression = new InExpression(); + addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions); + } + break; + } + } + } + } + + // add in condition to sql where condition + private void addWhereInFilters(QueryFilter dslQueryFilter, + InExpression inExpression, + Set contextMetricFilters, + List addConditions) { + Column column = new Column(dslQueryFilter.getName()); + ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); + List valueList = JsonUtil.toList( + JsonUtil.toString(dslQueryFilter.getValue()), String.class); + if (org.apache.commons.collections.CollectionUtils.isEmpty(valueList)) { + return; + } + valueList.stream().forEach(o -> { + StringValue stringValue = new StringValue(o); + parenthesedExpressionList.add(stringValue); + }); + inExpression.setLeftExpression(column); + inExpression.setRightExpression(parenthesedExpressionList); + addConditions.add(inExpression); + contextMetricFilters.stream().forEach(o -> { + if (o.getName().equals(dslQueryFilter.getName())) { + o.setValue(dslQueryFilter.getValue()); + o.setOperator(dslQueryFilter.getOperator()); + } + }); + } + + // add where filter + private void addWhereFilters(QueryFilter dslQueryFilter, + T comparisonExpression, + Set contextMetricFilters, + List addConditions) { + String columnName = dslQueryFilter.getName(); + if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { + columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; + } + if (Objects.isNull(dslQueryFilter.getValue())) { + return; + } + Column column = new Column(columnName); + comparisonExpression.setLeftExpression(column); + if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) { + LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString())); + comparisonExpression.setRightExpression(longValue); + } else { + StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString()); + comparisonExpression.setRightExpression(stringValue); + } + addConditions.add(comparisonExpression); + contextMetricFilters.stream().forEach(o -> { + if (o.getName().equals(dslQueryFilter.getName())) { + o.setValue(dslQueryFilter.getValue()); + o.setOperator(dslQueryFilter.getOperator()); + } + }); + } + + private SemanticParseInfo mergeSemanticParseInfo(SemanticParseInfo parseInfo, + ChatQueryDataReq queryData) { + if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { + return parseInfo; + } + if (!CollectionUtils.isEmpty(queryData.getDimensions())) { + parseInfo.setDimensions(queryData.getDimensions()); + } + if (!CollectionUtils.isEmpty(queryData.getMetrics())) { + parseInfo.setMetrics(queryData.getMetrics()); + } + if (!CollectionUtils.isEmpty(queryData.getDimensionFilters())) { + parseInfo.setDimensionFilters(queryData.getDimensionFilters()); + } + if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) { + parseInfo.setMetricFilters(queryData.getMetricFilters()); + } + if (Objects.nonNull(queryData.getDateInfo())) { + parseInfo.setDateInfo(queryData.getDateInfo()); + } + return parseInfo; + } + + private void validFilter(Set filters) { + for (QueryFilter queryFilter : filters) { + if (Objects.isNull(queryFilter.getValue())) { + filters.remove(queryFilter); + } + if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty( + JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) { + filters.remove(queryFilter); + } + } } @Override - public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception { + public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) { Integer agentId = dimensionValueReq.getAgentId(); Agent agent = agentService.getAgent(agentId); dimensionValueReq.setDataSetIds(agent.getDataSetIds()); - return chatLayerService.queryDimensionValue(dimensionValueReq, user); + return semanticLayerService.queryDimensionValue(dimensionValueReq, user); } public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index fc0c8363c..8ff0a620e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -1,14 +1,16 @@ package com.tencent.supersonic.headless.api.pojo; +import lombok.Data; +import org.apache.commons.collections.CollectionUtils; + import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import lombok.Data; -import org.apache.commons.collections.CollectionUtils; @Data public class DataSetSchema { @@ -57,6 +59,14 @@ public class DataSetSchema { } } + public Map getBizNameToName() { + List allElements = new ArrayList<>(); + allElements.addAll(getDimensions()); + allElements.addAll(getMetrics()); + return allElements.stream() + .collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1)); + } + public TimeDefaultConfig getTagTypeTimeDefaultConfig() { if (queryConfig == null) { return null; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java index 6f3eb77df..bb9669436 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java @@ -150,14 +150,6 @@ public class SemanticSchema implements Serializable { return dataSets; } - public Map getBizNameToName(Long dataSetId) { - List allElements = new ArrayList<>(); - allElements.addAll(getDimensions(dataSetId)); - allElements.addAll(getMetrics(dataSetId)); - return allElements.stream() - .collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1)); - } - public Map getDataSetSchemaMap() { if (CollectionUtils.isEmpty(dataSetSchemaList)) { return new HashMap<>(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java index 864e30266..7f466b5c7 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java @@ -1,6 +1,8 @@ package com.tencent.supersonic.headless.api.pojo.request; +import com.tencent.supersonic.common.pojo.DateConf; import lombok.Data; +import org.apache.commons.lang3.StringUtils; import javax.validation.constraints.NotNull; import java.util.Set; @@ -21,4 +23,11 @@ public class DimensionValueReq { private Set dataSetIds; + private DateConf dateInfo = new DateConf(); + + private String dimensionBizName; + + public String getBizName() { + return StringUtils.isBlank(bizName) ? dimensionBizName : bizName; + } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryDimValueReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryDimValueReq.java deleted file mode 100644 index b5030de83..000000000 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryDimValueReq.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.tencent.supersonic.headless.api.pojo.request; - -import com.tencent.supersonic.common.pojo.DateConf; -import lombok.Data; -import lombok.ToString; - -@Data -@ToString -public class QueryDimValueReq { - - private Long modelId; - private String dimensionBizName; - private String value; - private DateConf dateInfo = new DateConf(); - -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index b089c62db..6a5bf06a4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -4,6 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; @@ -35,7 +36,10 @@ public class QueryTypeParser implements SemanticParser { for (SemanticQuery semanticQuery : candidateQueries) { // 1.init S2SQL - semanticQuery.initS2Sql(chatQueryContext.getSemanticSchema(), user); + Long dataSetId = semanticQuery.getParseInfo().getDataSetId(); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema() + .getDataSetSchemaMap().get(dataSetId); + semanticQuery.initS2Sql(dataSetSchema, user); // 2.set queryType QueryType queryType = getQueryType(chatQueryContext, semanticQuery); semanticQuery.getParseInfo().setQueryType(queryType); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java index b479633ac..ec1a78ad3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java @@ -6,8 +6,8 @@ import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.chat.parser.ParserConfig; @@ -43,8 +43,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { return QueryReqBuilder.buildStructReq(parseInfo); } - protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) { - Map bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId()); + protected void convertBizNameToName(DataSetSchema dataSetSchema, QueryStructReq queryStructReq) { + Map bizNameToName = dataSetSchema.getBizNameToName(); bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap()); List orders = queryStructReq.getOrders(); @@ -74,14 +74,14 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { } } - protected void initS2SqlByStruct(SemanticSchema semanticSchema) { + protected void initS2SqlByStruct(DataSetSchema dataSetSchema) { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE)); if (!s2sqlEnable) { return; } QueryStructReq queryStructReq = convertQueryStruct(); - convertBizNameToName(semanticSchema, queryStructReq); + convertBizNameToName(dataSetSchema, queryStructReq); QuerySqlReq querySQLReq = queryStructReq.convert(); parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql()); parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java index a4859b8b1..bd1faa7a5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java @@ -1,8 +1,8 @@ package com.tencent.supersonic.headless.chat.query; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import org.apache.calcite.sql.parser.SqlParseException; @@ -15,7 +15,7 @@ public interface SemanticQuery { SemanticQueryReq buildSemanticQueryReq() throws SqlParseException; - void initS2Sql(SemanticSchema semanticSchema, User user); + void initS2Sql(DataSetSchema dataSetSchema, User user); SemanticParseInfo getParseInfo(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java index 35e3e275e..fdce9b235 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java @@ -1,12 +1,12 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery; +import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -31,7 +31,7 @@ public class LLMSqlQuery extends LLMSemanticQuery { } @Override - public void initS2Sql(SemanticSchema semanticSchema, User user) { + public void initS2Sql(DataSetSchema dataSetSchema, User user) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL()); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index 6e7b1be65..b5c26b9b6 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.rule; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; @@ -13,9 +14,13 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.chat.ChatQueryContext; -import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery; import com.tencent.supersonic.headless.chat.query.QueryManager; +import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; +import lombok.ToString; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; + import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -24,9 +29,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.stream.Collectors; -import lombok.ToString; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; @Slf4j @ToString @@ -44,8 +46,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } @Override - public void initS2Sql(SemanticSchema semanticSchema, User user) { - initS2SqlByStruct(semanticSchema); + public void initS2Sql(DataSetSchema dataSetSchema, User user) { + initS2SqlByStruct(dataSetSchema); } public void fillParseInfo(ChatQueryContext chatQueryContext) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/ChatLayerService.java index 915c188c4..3a6b3df97 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/ChatLayerService.java @@ -2,15 +2,12 @@ package com.tencent.supersonic.headless.server.facade.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; -import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.api.pojo.response.QueryResult; /***dd * SemanticLayerService for query and search @@ -21,10 +18,6 @@ public interface ChatLayerService { ParseResp performParsing(QueryNLReq queryNLReq); - QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception; - - Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; - MapInfoResp map(QueryMapReq queryMapReq); void correct(QuerySqlReq querySqlReq, User user); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java index 27e569756..574a6db0a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java @@ -4,11 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; +import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import java.util.List; @@ -23,7 +23,7 @@ public interface SemanticLayerService { SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception; - SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); + SemanticQueryResp queryDimensionValue(DimensionValueReq dimensionValueReq, User user); EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index a6d264e57..2036171da 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -2,20 +2,8 @@ package com.tencent.supersonic.headless.server.facade.service.impl; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.jsqlparser.FieldExpression; -import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; -import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; -import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; -import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; -import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.DateUtils; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; @@ -25,60 +13,27 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.SqlInfo; -import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; -import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; +import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; -import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; -import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; -import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.api.pojo.response.QueryResult; -import com.tencent.supersonic.headless.api.pojo.response.QueryState; -import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector; -import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; -import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService; -import com.tencent.supersonic.headless.chat.knowledge.SearchService; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; -import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; -import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; -import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; -import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; -import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine; import com.tencent.supersonic.headless.server.pojo.MetaFilter; +import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine; import com.tencent.supersonic.headless.server.utils.ComponentFactory; import com.tencent.supersonic.headless.server.web.service.DataSetService; import com.tencent.supersonic.headless.server.web.service.SchemaService; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.StringValue; -import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; -import net.sf.jsqlparser.expression.operators.relational.EqualsTo; -import net.sf.jsqlparser.expression.operators.relational.GreaterThan; -import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; -import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.expression.operators.relational.MinorThan; -import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; -import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList; -import net.sf.jsqlparser.schema.Column; import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -86,7 +41,6 @@ import org.springframework.stereotype.Service; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -97,13 +51,9 @@ import java.util.stream.Collectors; @Service @Slf4j public class S2ChatLayerService implements ChatLayerService { - @Autowired - private SemanticLayerService semanticLayerService; @Autowired private SchemaService schemaService; @Autowired - private KnowledgeBaseService knowledgeBaseService; - @Autowired private DataSetService dataSetService; @Autowired private ChatWorkflowEngine chatWorkflowEngine; @@ -166,386 +116,6 @@ public class S2ChatLayerService implements ChatLayerService { return queryCtx; } - //mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000", - //"style='流行'"->"style in ['流行','爱国']" - @Override - public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception { - SemanticParseInfo parseInfo = getSemanticParseInfo(queryData); - SemanticSchema semanticSchema = schemaService.getSemanticSchema(); - - SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); - semanticQuery.setParseInfo(parseInfo); - - List fields = new ArrayList<>(); - if (Objects.nonNull(parseInfo.getSqlInfo()) - && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { - String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); - fields = SqlSelectHelper.getAllFields(correctorSql); - } - if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) - && checkMetricReplace(fields, queryData.getMetrics())) { - //replace metrics - log.info("llm begin replace metrics!"); - SchemaElement metricToReplace = queryData.getMetrics().iterator().next(); - replaceMetrics(parseInfo, metricToReplace); - } else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { - log.info("llm begin revise filters!"); - String correctorSql = reviseCorrectS2SQL(queryData, parseInfo); - parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); - semanticQuery.setParseInfo(parseInfo); - SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); - parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); - } else { - log.info("rule begin replace metrics and revise filters!"); - //remove unvalid filters - validFilter(semanticQuery.getParseInfo().getDimensionFilters()); - validFilter(semanticQuery.getParseInfo().getMetricFilters()); - //init s2sql - semanticQuery.initS2Sql(semanticSchema, user); - QueryNLReq queryNLReq = new QueryNLReq(); - queryNLReq.setQueryFilters(new QueryFilters()); - queryNLReq.setUser(user); - } - SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user); - queryResult.setChatContext(semanticQuery.getParseInfo()); - DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId()); - SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); - EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user); - queryResult.setEntityInfo(entityInfo); - return queryResult; - } - - private QueryResult doExecution(SemanticQueryReq semanticQueryReq, - SemanticParseInfo parseInfo, User user) throws Exception { - SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user); - QueryResult queryResult = new QueryResult(); - if (queryResp != null) { - queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); - } - - String sql = queryResp == null ? null : queryResp.getSql(); - List> resultList = queryResp == null ? new ArrayList<>() - : queryResp.getResultList(); - List columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns(); - queryResult.setQuerySql(sql); - queryResult.setQueryResults(resultList); - queryResult.setQueryColumns(columns); - queryResult.setQueryMode(parseInfo.getQueryMode()); - queryResult.setQueryState(QueryState.SUCCESS); - - return queryResult; - } - - private boolean checkMetricReplace(List oriFields, Set metrics) { - if (CollectionUtils.isEmpty(oriFields)) { - return false; - } - if (CollectionUtils.isEmpty(metrics)) { - return false; - } - List metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()); - return !oriFields.containsAll(metricNames); - } - - private String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) { - Map> filedNameToValueMap = new HashMap<>(); - Map> havingFiledNameToValueMap = new HashMap<>(); - - String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); - log.info("correctorSql before replacing:{}", correctorSql); - // get where filter and having filter - List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql); - List havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql); - List addWhereConditions = new ArrayList<>(); - List addHavingConditions = new ArrayList<>(); - Set removeWhereFieldNames = new HashSet<>(); - Set removeHavingFieldNames = new HashSet<>(); - // replace where filter - updateFilters(whereExpressionList, queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); - updateDateInfo(queryData, parseInfo, filedNameToValueMap, - whereExpressionList, addWhereConditions, removeWhereFieldNames); - correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); - correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); - // replace having filter - updateFilters(havingExpressionList, queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); - correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); - correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); - - correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions); - correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions); - log.info("correctorSql after replacing:{}", correctorSql); - return correctorSql; - } - - private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { - List oriMetrics = parseInfo.getMetrics().stream() - .map(SchemaElement::getName).collect(Collectors.toList()); - String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); - log.info("before replaceMetrics:{}", correctorSql); - log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); - Map> fieldMap = new HashMap<>(); - if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) { - fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg())); - correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap); - } - log.info("after replaceMetrics:{}", correctorSql); - parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); - } - - private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, - Map> filedNameToValueMap, - List fieldExpressionList, - List addConditions, - Set removeFieldNames) { - if (Objects.isNull(queryData.getDateInfo())) { - return; - } - if (queryData.getDateInfo().getUnit() > 1) { - queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); - queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1)); - } - // startDate equals to endDate - for (FieldExpression fieldExpression : fieldExpressionList) { - if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) { - // first remove,then add - removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); - GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); - addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); - MinorThanEquals minorThanEquals = new MinorThanEquals(); - addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions); - break; - } - } - for (FieldExpression fieldExpression : fieldExpressionList) { - for (QueryFilter queryFilter : queryData.getDimensionFilters()) { - if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE) - && FilterOperatorEnum.LIKE.getValue().toLowerCase().equals( - fieldExpression.getOperator().toLowerCase())) { - Map replaceMap = new HashMap<>(); - String preValue = fieldExpression.getFieldValue().toString(); - String curValue = queryFilter.getValue().toString(); - if (preValue.startsWith("%")) { - curValue = "%" + curValue; - } - if (preValue.endsWith("%")) { - curValue = curValue + "%"; - } - replaceMap.put(preValue, curValue); - filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap); - break; - } - } - } - parseInfo.setDateInfo(queryData.getDateInfo()); - } - - private void addTimeFilters(String date, - T comparisonExpression, - List addConditions) { - Column column = new Column(TimeDimensionEnum.DAY.getChName()); - StringValue stringValue = new StringValue(date); - comparisonExpression.setLeftExpression(column); - comparisonExpression.setRightExpression(stringValue); - addConditions.add(comparisonExpression); - } - - private void updateFilters(List fieldExpressionList, - Set metricFilters, - Set contextMetricFilters, - List addConditions, - Set removeFieldNames) { - if (CollectionUtils.isEmpty(metricFilters)) { - return; - } - for (QueryFilter dslQueryFilter : metricFilters) { - for (FieldExpression fieldExpression : fieldExpressionList) { - if (fieldExpression.getFieldName() != null - && fieldExpression.getFieldName().contains(dslQueryFilter.getName())) { - removeFieldNames.add(dslQueryFilter.getName()); - if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { - EqualsTo equalsTo = new EqualsTo(); - addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) { - GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); - addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) { - GreaterThan greaterThan = new GreaterThan(); - addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) { - MinorThanEquals minorThanEquals = new MinorThanEquals(); - addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) { - MinorThan minorThan = new MinorThan(); - addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) { - InExpression inExpression = new InExpression(); - addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions); - } - break; - } - } - } - } - - // add in condition to sql where condition - private void addWhereInFilters(QueryFilter dslQueryFilter, - InExpression inExpression, - Set contextMetricFilters, - List addConditions) { - Column column = new Column(dslQueryFilter.getName()); - ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); - List valueList = JsonUtil.toList( - JsonUtil.toString(dslQueryFilter.getValue()), String.class); - if (CollectionUtils.isEmpty(valueList)) { - return; - } - valueList.stream().forEach(o -> { - StringValue stringValue = new StringValue(o); - parenthesedExpressionList.add(stringValue); - }); - inExpression.setLeftExpression(column); - inExpression.setRightExpression(parenthesedExpressionList); - addConditions.add(inExpression); - contextMetricFilters.stream().forEach(o -> { - if (o.getName().equals(dslQueryFilter.getName())) { - o.setValue(dslQueryFilter.getValue()); - o.setOperator(dslQueryFilter.getOperator()); - } - }); - } - - // add where filter - private void addWhereFilters(QueryFilter dslQueryFilter, - T comparisonExpression, - Set contextMetricFilters, - List addConditions) { - String columnName = dslQueryFilter.getName(); - if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { - columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; - } - if (Objects.isNull(dslQueryFilter.getValue())) { - return; - } - Column column = new Column(columnName); - comparisonExpression.setLeftExpression(column); - if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) { - LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString())); - comparisonExpression.setRightExpression(longValue); - } else { - StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString()); - comparisonExpression.setRightExpression(stringValue); - } - addConditions.add(comparisonExpression); - contextMetricFilters.stream().forEach(o -> { - if (o.getName().equals(dslQueryFilter.getName())) { - o.setValue(dslQueryFilter.getValue()); - o.setOperator(dslQueryFilter.getOperator()); - } - }); - } - - private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData) { - SemanticParseInfo parseInfo = queryData.getParseInfo(); - if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { - return parseInfo; - } - if (CollectionUtils.isNotEmpty(queryData.getDimensions())) { - parseInfo.setDimensions(queryData.getDimensions()); - } - if (CollectionUtils.isNotEmpty(queryData.getMetrics())) { - parseInfo.setMetrics(queryData.getMetrics()); - } - if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { - parseInfo.setDimensionFilters(queryData.getDimensionFilters()); - } - if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) { - parseInfo.setMetricFilters(queryData.getMetricFilters()); - } - if (Objects.nonNull(queryData.getDateInfo())) { - parseInfo.setDateInfo(queryData.getDateInfo()); - } - return parseInfo; - } - - private void validFilter(Set filters) { - for (QueryFilter queryFilter : filters) { - if (Objects.isNull(queryFilter.getValue())) { - filters.remove(queryFilter); - } - if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty( - JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) { - filters.remove(queryFilter); - } - } - } - - @Override - public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception { - SemanticQueryResp semanticQueryResp = new SemanticQueryResp(); - DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID()); - Set dataSetIds = dimensionValueReq.getDataSetIds(); - dimensionValueReq.setModelId(dimensionResp.getModelId()); - List dimensionValues = getDimensionValues(dimensionValueReq, dataSetIds); - // if the search results is null,search dimensionValue from database - if (CollectionUtils.isEmpty(dimensionValues)) { - semanticQueryResp = queryDatabase(dimensionValueReq, user); - return semanticQueryResp; - } - List columns = new ArrayList<>(); - QueryColumn queryColumn = new QueryColumn(); - queryColumn.setNameEn(dimensionValueReq.getBizName()); - queryColumn.setShowType("CATEGORY"); - queryColumn.setAuthorized(true); - queryColumn.setType("CHAR"); - columns.add(queryColumn); - List> resultList = new ArrayList<>(); - dimensionValues.stream().forEach(o -> { - Map map = new HashMap<>(); - map.put(dimensionValueReq.getBizName(), o); - resultList.add(map); - }); - semanticQueryResp.setColumns(columns); - semanticQueryResp.setResultList(resultList); - return semanticQueryResp; - } - - private List getDimensionValues(DimensionValueReq dimensionValueReq, Set dataSetIds) { - //if value is null ,then search from NATURE_TO_VALUES - if (StringUtils.isBlank(dimensionValueReq.getValue())) { - return SearchService.getDimensionValue(dimensionValueReq); - } - Map> modelIdToDataSetIds = new HashMap<>(); - modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds)); - //search from prefixSearch - List hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(), - 2000, modelIdToDataSetIds, dataSetIds); - HanlpHelper.transLetterOriginal(hanlpMapResultList); - return hanlpMapResultList.stream() - .filter(o -> { - for (String nature : o.getNatures()) { - Long elementID = NatureHelper.getElementID(nature); - if (dimensionValueReq.getElementID().equals(elementID)) { - return true; - } - } - return false; - }) - .map(mapResult -> mapResult.getName()) - .collect(Collectors.toList()); - } - - private SemanticQueryResp queryDatabase(DimensionValueReq dimensionValueReq, User user) { - QueryDimValueReq queryDimValueReq = new QueryDimValueReq(); - queryDimValueReq.setValue(dimensionValueReq.getValue()); - queryDimValueReq.setModelId(dimensionValueReq.getModelId()); - queryDimValueReq.setDimensionBizName(dimensionValueReq.getBizName()); - return semanticLayerService.queryDimValue(queryDimValueReq, user); - } - public void correct(QuerySqlReq querySqlReq, User user) { SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user); querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index 3fff180fc..921cc9e60 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -4,6 +4,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; @@ -19,7 +20,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; +import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; +import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; @@ -27,11 +29,17 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; -import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; +import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; +import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService; +import com.tencent.supersonic.headless.chat.knowledge.MapResult; +import com.tencent.supersonic.headless.chat.knowledge.SearchService; +import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; +import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.core.cache.QueryCache; import com.tencent.supersonic.headless.core.executor.QueryExecutor; @@ -57,6 +65,7 @@ import org.springframework.stereotype.Service; import java.time.LocalDate; import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -77,6 +86,7 @@ public class S2SemanticLayerService implements SemanticLayerService { private final SchemaService schemaService; private final SemanticTranslator semanticTranslator; private final MetricDrillDownChecker metricDrillDownChecker; + private final KnowledgeBaseService knowledgeBaseService; private QueryCache queryCache = ComponentFactory.getQueryCache(); private List queryExecutors = ComponentFactory.getQueryExecutors(); @@ -88,7 +98,8 @@ public class S2SemanticLayerService implements SemanticLayerService { DataSetService dataSetService, SchemaService schemaService, SemanticTranslator semanticTranslator, - MetricDrillDownChecker metricDrillDownChecker) { + MetricDrillDownChecker metricDrillDownChecker, + KnowledgeBaseService knowledgeBaseService) { this.statUtils = statUtils; this.queryUtils = queryUtils; this.queryReqConverter = queryReqConverter; @@ -97,6 +108,7 @@ public class S2SemanticLayerService implements SemanticLayerService { this.schemaService = schemaService; this.semanticTranslator = semanticTranslator; this.metricDrillDownChecker = metricDrillDownChecker; + this.knowledgeBaseService = knowledgeBaseService; } public DataSetSchema getDataSetSchema(Long id) { @@ -175,12 +187,74 @@ public class S2SemanticLayerService implements SemanticLayerService { } @Override - @SneakyThrows - public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) { - QuerySqlReq querySqlReq = buildQuerySqlReq(queryDimValueReq); + public SemanticQueryResp queryDimensionValue(DimensionValueReq dimensionValueReq, User user) { + SemanticQueryResp semanticQueryResp = new SemanticQueryResp(); + DimensionResp dimensionResp = getDimension(dimensionValueReq); + Set dataSetIds = dimensionValueReq.getDataSetIds(); + dimensionValueReq.setModelId(dimensionResp.getModelId()); + List dimensionValues = getDimensionValuesFromDict(dimensionValueReq, dataSetIds); + // if the search results is null,search dimensionValue from database + if (CollectionUtils.isEmpty(dimensionValues)) { + semanticQueryResp = getDimensionValuesFromDb(dimensionValueReq, user); + return semanticQueryResp; + } + List columns = new ArrayList<>(); + QueryColumn queryColumn = new QueryColumn(); + queryColumn.setNameEn(dimensionValueReq.getBizName()); + queryColumn.setShowType(SemanticType.CATEGORY.name()); + queryColumn.setAuthorized(true); + queryColumn.setType("CHAR"); + columns.add(queryColumn); + List> resultList = new ArrayList<>(); + dimensionValues.stream().forEach(o -> { + Map map = new HashMap<>(); + map.put(dimensionValueReq.getBizName(), o); + resultList.add(map); + }); + semanticQueryResp.setColumns(columns); + semanticQueryResp.setResultList(resultList); + return semanticQueryResp; + } + + private List getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, Set dataSetIds) { + //if value is null ,then search from NATURE_TO_VALUES + if (StringUtils.isBlank(dimensionValueReq.getValue())) { + return SearchService.getDimensionValue(dimensionValueReq); + } + Map> modelIdToDataSetIds = new HashMap<>(); + modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds)); + //search from prefixSearch + List hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(), + 2000, modelIdToDataSetIds, dataSetIds); + HanlpHelper.transLetterOriginal(hanlpMapResultList); + return hanlpMapResultList.stream() + .filter(o -> { + for (String nature : o.getNatures()) { + Long elementID = NatureHelper.getElementID(nature); + if (dimensionValueReq.getElementID().equals(elementID)) { + return true; + } + } + return false; + }) + .map(MapResult::getName) + .collect(Collectors.toList()); + } + + private SemanticQueryResp getDimensionValuesFromDb(DimensionValueReq dimensionValueReq, User user) { + QuerySqlReq querySqlReq = buildQuerySqlReq(dimensionValueReq); return queryByReq(querySqlReq, user); } + private DimensionResp getDimension(DimensionValueReq dimensionValueReq) { + DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID()); + if (dimensionResp == null) { + return schemaService.getDimension(dimensionValueReq.getBizName(), + dimensionValueReq.getModelId()); + } + return dimensionResp; + } + public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) { if (parseInfo != null && parseInfo.getDataSetId() != null && parseInfo.getDataSetId() > 0) { EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema); @@ -291,10 +365,10 @@ public class S2SemanticLayerService implements SemanticLayerService { return schemaFilterReq; } - private QuerySqlReq buildQuerySqlReq(QueryDimValueReq queryDimValueReq) { + private QuerySqlReq buildQuerySqlReq(DimensionValueReq queryDimValueReq) { QuerySqlReq querySqlReq = new QuerySqlReq(); List modelResps = schemaService.getModelList(Lists.newArrayList(queryDimValueReq.getModelId())); - DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getDimensionBizName(), + DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getBizName(), queryDimValueReq.getModelId()); ModelResp modelResp = modelResps.get(0); String sql = String.format("select distinct %s from %s where 1=1", @@ -306,7 +380,7 @@ public class S2SemanticLayerService implements SemanticLayerService { queryDimValueReq.getDateInfo().getEndDate()); } if (StringUtils.isNotBlank(queryDimValueReq.getValue())) { - sql += " AND " + queryDimValueReq.getDimensionBizName() + " LIKE '%" + queryDimValueReq.getValue() + "%'"; + sql += " AND " + queryDimValueReq.getBizName() + " LIKE '%" + queryDimValueReq.getValue() + "%'"; } querySqlReq.setModelIds(Sets.newHashSet(queryDimValueReq.getModelId())); querySqlReq.setSql(sql); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/DimensionController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/DimensionController.java index bbe300916..66fbc3a64 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/DimensionController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/DimensionController.java @@ -7,15 +7,15 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; import com.tencent.supersonic.headless.api.pojo.DimValueMap; import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; +import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.pojo.DimensionFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.web.service.DimensionService; -import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.GetMapping; @@ -121,11 +121,11 @@ public class DimensionController { } @PostMapping("/queryDimValue") - public SemanticQueryResp queryDimValue(@RequestBody QueryDimValueReq queryDimValueReq, + public SemanticQueryResp queryDimValue(@RequestBody DimensionValueReq dimensionValueReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); - return queryService.queryDimValue(queryDimValueReq, user); + return queryService.queryDimensionValue(dimensionValueReq, user); } @DeleteMapping("deleteDimension/{id}") diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryDimensionTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryDimensionTest.java index 5b2909fae..117f221a6 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryDimensionTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryDimensionTest.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; +import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -10,11 +10,11 @@ public class QueryDimensionTest extends BaseTest { @Test public void testQueryDimValue() { - QueryDimValueReq queryDimValueReq = new QueryDimValueReq(); + DimensionValueReq queryDimValueReq = new DimensionValueReq(); queryDimValueReq.setModelId(1L); - queryDimValueReq.setDimensionBizName("department"); + queryDimValueReq.setBizName("department"); - SemanticQueryResp queryResp = semanticLayerService.queryDimValue(queryDimValueReq, User.getFakeUser()); + SemanticQueryResp queryResp = semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser()); Assert.assertNotNull(queryResp.getResultList()); Assert.assertEquals(4, queryResp.getResultList().size()); }