diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java index e53bef063..19e39704b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java @@ -5,5 +5,4 @@ package com.tencent.supersonic.chat.server.processor; */ public interface ResultProcessor { - } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/ExecuteResultProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/ExecuteResultProcessor.java index 98c95997e..cc2b32ed2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/ExecuteResultProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/ExecuteResultProcessor.java @@ -5,7 +5,8 @@ import com.tencent.supersonic.chat.server.processor.ResultProcessor; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; /** - * A ExecuteResultProcessor wraps things up before returning results to users in execute stage. + * A ExecuteResultProcessor wraps things up before returning + * execution results to the users. */ public interface ExecuteResultProcessor extends ResultProcessor { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java deleted file mode 100644 index 8ed027fc4..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java +++ /dev/null @@ -1,43 +0,0 @@ -package com.tencent.supersonic.chat.server.processor.parse; - -import com.tencent.supersonic.chat.server.pojo.ParseContext; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.EntityInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.chat.query.QueryManager; -import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; -import org.springframework.util.CollectionUtils; - -import java.util.List; - -/** - * EntityInfoProcessor fills core attributes of an entity so that - * users get to know which entity is parsed out. - */ -public class EntityInfoProcessor implements ParseResultProcessor { - - @Override - public void process(ParseContext parseContext, ParseResp parseResp) { - List selectedParses = parseResp.getSelectedParses(); - if (CollectionUtils.isEmpty(selectedParses)) { - return; - } - selectedParses.forEach(parseInfo -> { - String queryMode = parseInfo.getQueryMode(); - if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) { - return; - } - - //1. set entity info - SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); - DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId()); - EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, parseContext.getUser()); - if (QueryManager.isTagQuery(queryMode) - || QueryManager.isMetricQuery(queryMode)) { - parseInfo.setEntityInfo(entityInfo); - } - }); - } -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java index 4760d6c30..33bb53d02 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java @@ -1,9 +1,14 @@ package com.tencent.supersonic.chat.server.processor.parse; import com.tencent.supersonic.chat.server.pojo.ParseContext; +import com.tencent.supersonic.chat.server.processor.ResultProcessor; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -public interface ParseResultProcessor { +/** + * A ParseResultProcessor wraps things up before returning + * parsing results to the users. + */ +public interface ParseResultProcessor extends ResultProcessor { void process(ParseContext parseContext, ParseResp parseResp); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 339c42720..23701977e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -213,7 +213,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { if (Objects.nonNull(parseInfo.getSqlInfo()) && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); - fields = SqlSelectHelper.getAllFields(correctorSql); + fields = SqlSelectHelper.getAllSelectFields(correctorSql); } if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) && checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 90db3c064..6acd6ac7d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -244,7 +244,7 @@ public class SqlSelectHelper { return plainSelects; } - public static List getAllFields(String sql) { + public static List getAllSelectFields(String sql) { List plainSelects = getPlainSelects(getPlainSelect(sql)); Set results = new HashSet<>(); for (PlainSelect plainSelect : plainSelects) { diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java index a32979a54..e378cc580 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java @@ -28,8 +28,8 @@ public class SqlValidHelper { } //2. all fields - List thisAllFields = SqlSelectHelper.getAllFields(thisSql); - List otherAllFields = SqlSelectHelper.getAllFields(otherSql); + List thisAllFields = SqlSelectHelper.getAllSelectFields(thisSql); + List otherAllFields = SqlSelectHelper.getAllSelectFields(otherSql); if (!CollectionUtils.isEqualCollection(thisAllFields, otherAllFields)) { return false; diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java index cd5ad541b..ccec0c751 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java @@ -24,12 +24,12 @@ class SqlAddHelperTest { String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; sql = SqlAddHelper.addWhere(sql, "column_a", 123444555); - List selectFields = SqlSelectHelper.getAllFields(sql); + List selectFields = SqlSelectHelper.getAllSelectFields(sql); Assert.assertEquals(selectFields.contains("column_a"), true); sql = SqlAddHelper.addWhere(sql, "column_b", "123456666"); - selectFields = SqlSelectHelper.getAllFields(sql); + selectFields = SqlSelectHelper.getAllSelectFields(sql); Assert.assertEquals(selectFields.contains("column_b"), true); diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java index 35a8870f3..d9d1fa411 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java @@ -131,55 +131,55 @@ class SqlSelectHelperTest { @Test void testGetAllFields() { - List allFields = SqlSelectHelper.getAllFields( + List allFields = SqlSelectHelper.getAllSelectFields( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); Assert.assertEquals(allFields.size(), 6); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date >= '2023-08-08'" + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); Assert.assertEquals(allFields.size(), 6); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 = 'alice'" + " and 发布日期 ='11' group by 部门 limit 1"); Assert.assertEquals(allFields.size(), 5); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10 "); Assert.assertEquals(allFields.size(), 3); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"); Assert.assertEquals(allFields.size(), 3); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "SELECT department, user_id, field_a FROM s2 WHERE " + "(user_id = 'alice' AND publish_date = '11') and sys_imp_date " + "= '2023-08-08' ORDER BY pv DESC LIMIT 1"); Assert.assertEquals(allFields.size(), 6); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "SELECT * FROM CSpider WHERE (评分 < (SELECT min(评分) FROM CSpider WHERE 语种 = '英文' ))" + " AND 数据日期 = '2023-10-12'"); Assert.assertEquals(allFields.size(), 3); - allFields = SqlSelectHelper.getAllFields("SELECT sum(销量) / (SELECT sum(销量) FROM 营销 " + allFields = SqlSelectHelper.getAllSelectFields("SELECT sum(销量) / (SELECT sum(销量) FROM 营销 " + "WHERE MONTH(数据日期) = 9) FROM 营销 WHERE 国家中文名 = '中国' AND MONTH(数据日期) = 9"); Assert.assertEquals(allFields.size(), 3); - allFields = SqlSelectHelper.getAllFields( + allFields = SqlSelectHelper.getAllSelectFields( "SELECT 用户, 页面 FROM 超音数用户部门 GROUP BY 用户, 页面 ORDER BY count(*) DESC"); Assert.assertEquals(allFields.size(), 2); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java index bb9669436..14fdd731e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java @@ -87,7 +87,7 @@ public class SemanticSchema implements Serializable { public List getMetrics() { List metrics = new ArrayList<>(); - dataSetSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics())); + dataSetSchemaList.forEach(d -> metrics.addAll(d.getMetrics())); return metrics; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index 43932745d..1ebd9ec61 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -57,7 +57,7 @@ public class ChatQueryContext { public List getCandidateQueries() { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); - int parseShowCount = Integer.valueOf(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT)); + int parseShowCount = Integer.parseInt(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT)); candidateQueries = candidateQueries.stream() .sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(), Comparator.reverseOrder())) diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index 8753f81ae..c5d788430 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -478,8 +478,7 @@ public class S2SemanticLayerService implements SemanticLayerService { SemanticQueryResp queryResultWithColumns = getQueryResultWithSchemaResp(entityInfo, dataSetSchema, user); if (queryResultWithColumns != null) { - if (!org.springframework.util.CollectionUtils.isEmpty(queryResultWithColumns.getResultList()) - && queryResultWithColumns.getResultList().size() > 0) { + if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())) { Map result = queryResultWithColumns.getResultList().get(0); for (Map.Entry entry : result.entrySet()) { String entryKey = getEntryKey(entry); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java new file mode 100644 index 000000000..fe69f76dc --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java @@ -0,0 +1,31 @@ +package com.tencent.supersonic.headless.server.processor; + +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; +import com.tencent.supersonic.headless.api.pojo.EntityInfo; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.query.QueryManager; +import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; + +/** + * EntityInfoProcessor fills core attributes of an entity so that + * users get to know which entity is parsed out. + */ +public class EntityInfoProcessor implements ResultProcessor { + + @Override + public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) { + parseResp.getSelectedParses().forEach(parseInfo -> { + String queryMode = parseInfo.getQueryMode(); + if (!QueryManager.isTagQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) { + return; + } + + SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); + DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId()); + EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatQueryContext.getUser()); + parseInfo.setEntityInfo(entityInfo); + }); + } +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java index 493c5075c..d3a7521e9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java @@ -8,21 +8,19 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.chat.ChatQueryContext; -import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.server.service.SchemaService; +import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; @@ -40,94 +38,79 @@ public class ParseInfoProcessor implements ResultProcessor { @Override public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) { - List candidateQueries = chatQueryContext.getCandidateQueries(); - if (CollectionUtils.isEmpty(candidateQueries)) { - return; - } - List candidateParses = candidateQueries.stream() - .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); - candidateParses.forEach(this::updateParseInfo); + parseResp.getSelectedParses().forEach(this::updateParseInfo); } public void updateParseInfo(SemanticParseInfo parseInfo) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); - String correctS2SQL = sqlInfo.getCorrectedS2SQL(); - if (StringUtils.isBlank(correctS2SQL)) { + String s2SQL = sqlInfo.getCorrectedS2SQL(); + if (StringUtils.isBlank(s2SQL)) { return; } - List expressions = SqlSelectHelper.getFilterExpression(correctS2SQL); + List expressions = SqlSelectHelper.getFilterExpression(s2SQL); - //set dataInfo + //extract date filter from S2SQL try { - if (!org.apache.commons.collections.CollectionUtils.isEmpty(expressions)) { - DateConf dateInfo = getDateInfo(expressions); - if (dateInfo != null && parseInfo.getDateInfo() == null) { - parseInfo.setDateInfo(dateInfo); - } + if (parseInfo.getDateInfo() == null && !CollectionUtils.isEmpty(expressions)) { + parseInfo.setDateInfo(extractDateFilter(expressions)); } } catch (Exception e) { - log.error("set dateInfo error :", e); + log.error("failed to extract date range:", e); } - if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) { - return; - } - //set filter + //extract dimension filters from S2SQL Long dataSetId = parseInfo.getDataSetId(); + SemanticLayerService semanticLayerService = ContextUtils.getBean(SemanticLayerService.class); + DataSetSchema dsSchema = semanticLayerService.getDataSetSchema(dataSetId); + try { - Map fieldNameToElement = getNameToElement(dataSetId); - List result = getDimensionFilter(fieldNameToElement, expressions); - parseInfo.getDimensionFilters().addAll(result); + Map fieldNameToElement = getNameToElement(dsSchema); + parseInfo.getDimensionFilters().addAll(extractDimensionFilter(fieldNameToElement, expressions)); } catch (Exception e) { - log.error("set dimensionFilter error :", e); + log.error("failed to extract dimension filters:", e); } - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - if (Objects.isNull(semanticSchema)) { - return; - } - List allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL())); - Set metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics()); + //extract metrics from S2SQL + List allFields = filterDateField(SqlSelectHelper.getAllSelectFields(s2SQL)); + Set metrics = matchSchemaElements(allFields, dsSchema.getMetrics()); parseInfo.setMetrics(metrics); + + //extract dimensions from S2SQL if (QueryType.METRIC.equals(parseInfo.getQueryType())) { - List groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL()); - List groupByDimensions = getFieldsExceptDate(groupByFields); - parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions())); + List groupByFields = SqlSelectHelper.getGroupByFields(s2SQL); + List groupByDimensions = filterDateField(groupByFields); + parseInfo.setDimensions(matchSchemaElements(groupByDimensions, dsSchema.getDimensions())); } else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) { - List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL()); - List selectDimensions = getFieldsExceptDate(selectFields); - parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions())); + List selectFields = SqlSelectHelper.getSelectFields(s2SQL); + List selectDimensions = filterDateField(selectFields); + parseInfo.setDimensions(matchSchemaElements(selectDimensions, dsSchema.getDimensions())); } } - private Set getElements(Long dataSetId, List allFields, List elements) { + private Set matchSchemaElements(List allFields, Set elements) { return elements.stream() .filter(schemaElement -> { if (CollectionUtils.isEmpty(schemaElement.getAlias())) { - return dataSetId.equals(schemaElement.getDataSet()) && allFields.contains( - schemaElement.getName()); + return allFields.contains(schemaElement.getName()); } Set allFieldsSet = new HashSet<>(allFields); Set aliasSet = new HashSet<>(schemaElement.getAlias()); List intersection = allFieldsSet.stream() .filter(aliasSet::contains).collect(Collectors.toList()); - return dataSetId.equals(schemaElement.getDataSet()) && (allFields.contains( - schemaElement.getName()) || !CollectionUtils.isEmpty(intersection)); + return allFields.contains(schemaElement.getName()) + || !CollectionUtils.isEmpty(intersection); } ).collect(Collectors.toSet()); } - private List getFieldsExceptDate(List allFields) { - if (org.springframework.util.CollectionUtils.isEmpty(allFields)) { - return new ArrayList<>(); - } + private List filterDateField(List allFields) { return allFields.stream() .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry)) .collect(Collectors.toList()); } - private List getDimensionFilter(Map fieldNameToElement, - List fieldExpressions) { + private List extractDimensionFilter(Map fieldNameToElement, + List fieldExpressions) { List result = Lists.newArrayList(); for (FieldExpression expression : fieldExpressions) { QueryFilter dimensionFilter = new QueryFilter(); @@ -148,7 +131,7 @@ public class ParseInfoProcessor implements ResultProcessor { return result; } - private DateConf getDateInfo(List fieldExpressions) { + private DateConf extractDateFilter(List fieldExpressions) { List dateExpressions = fieldExpressions.stream() .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) .collect(Collectors.toList()); @@ -193,10 +176,9 @@ public class ParseInfoProcessor implements ResultProcessor { return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); } - protected Map getNameToElement(Long dataSetId) { - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - List dimensions = semanticSchema.getDimensions(dataSetId); - List metrics = semanticSchema.getMetrics(dataSetId); + protected Map getNameToElement(DataSetSchema dsSchema) { + Set dimensions = dsSchema.getDimensions(); + Set metrics = dsSchema.getMetrics(); List allElements = Lists.newArrayList(); allElements.addAll(dimensions); @@ -214,7 +196,7 @@ public class ParseInfoProcessor implements ResultProcessor { } return result.stream(); }) - .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), + .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (value1, value2) -> value2)); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 5b7966274..e9adb24aa 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -33,10 +33,10 @@ import java.util.stream.Collectors; public class ChatWorkflowEngine { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private List schemaMappers = ComponentFactory.getSchemaMappers(); - private List semanticParsers = ComponentFactory.getSemanticParsers(); - private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); - private List resultProcessors = ComponentFactory.getResultProcessors(); + private final List schemaMappers = ComponentFactory.getSchemaMappers(); + private final List semanticParsers = ComponentFactory.getSemanticParsers(); + private final List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); + private final List resultProcessors = ComponentFactory.getResultProcessors(); public void execute(ChatQueryContext queryCtx, ParseResp parseResult) { queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); @@ -44,7 +44,7 @@ public class ChatWorkflowEngine { switch (queryCtx.getChatWorkflowState()) { case MAPPING: performMapping(queryCtx); - if (queryCtx.getMapInfo().getMatchedDataSetInfos().size() == 0) { + if (queryCtx.getMapInfo().getMatchedDataSetInfos().isEmpty()) { parseResult.setState(ParseResp.ParseState.FAILED); parseResult.setErrorMsg("No semantic entities can be mapped against user question."); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); @@ -54,7 +54,7 @@ public class ChatWorkflowEngine { break; case PARSING: performParsing(queryCtx); - if (queryCtx.getCandidateQueries().size() == 0) { + if (queryCtx.getCandidateQueries().isEmpty()) { parseResult.setState(ParseResp.ParseState.FAILED); parseResult.setErrorMsg("No semantic queries can be parsed out."); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java index 2565800b9..c6fcadd0c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java @@ -78,7 +78,7 @@ public class QueryReqConverter { querySQLReq.setSql(SqlReplaceHelper.replaceAggAliasOrderItem(querySQLReq.getSql())); log.debug("replaceOrderAggSameAlias {} -> {}", reqSql, querySQLReq.getSql()); //4.build MetricTables - List allFields = SqlSelectHelper.getAllFields(querySQLReq.getSql()); + List allFields = SqlSelectHelper.getAllSelectFields(querySQLReq.getSql()); List metricSchemas = getMetrics(semanticSchemaResp, allFields); List metrics = metricSchemas.stream().map(m -> m.getBizName()).collect(Collectors.toList()); QueryStructReq queryStructReq = new QueryStructReq(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java index 08e6035a7..4cd7be617 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java @@ -124,7 +124,7 @@ public class QueryStructUtils { } public Set getResName(QuerySqlReq querySqlReq) { - return new HashSet<>(SqlSelectHelper.getAllFields(querySqlReq.getSql())); + return new HashSet<>(SqlSelectHelper.getAllSelectFields(querySqlReq.getSql())); } public Set getBizNameFromSql(QuerySqlReq querySqlReq, diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java index e1c5d674e..b248a8ec0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java @@ -141,7 +141,7 @@ public class StatUtils { public void initSqlStatInfo(QuerySqlReq querySqlReq, User facadeUser) { QueryStat queryStatInfo = new QueryStat(); List aggFields = SqlSelectHelper.getAggregateFields(querySqlReq.getSql()); - List allFields = SqlSelectHelper.getAllFields(querySqlReq.getSql()); + List allFields = SqlSelectHelper.getAllSelectFields(querySqlReq.getSql()); List dimensions = allFields.stream().filter(aggFields::contains).collect(Collectors.toList()); String userName = getUserName(facadeUser); diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 21fe6b2e2..a2421f8d4 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -47,7 +47,8 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\ ### headless-server SPIs com.tencent.supersonic.headless.server.processor.ResultProcessor=\ - com.tencent.supersonic.headless.server.processor.ParseInfoProcessor + com.tencent.supersonic.headless.server.processor.ParseInfoProcessor,\ + com.tencent.supersonic.headless.server.processor.EntityInfoProcessor ### chat-server SPIs @@ -66,7 +67,6 @@ com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\ com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\ - com.tencent.supersonic.chat.server.processor.parse.EntityInfoProcessor,\ com.tencent.supersonic.chat.server.processor.parse.TimeCostProcessor com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\