diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/ModelSchema.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/ModelSchema.java index 9d7a3c731..ec5e2e268 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/ModelSchema.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/ModelSchema.java @@ -44,4 +44,33 @@ public class ModelSchema { } } +/** Looks up a schema element by type: ENTITY returns the entity field (possibly null), MODEL returns the model field, and METRIC, DIMENSION and VALUE return the first element of the matching collection whose name equals the given name. Returns null when nothing matches or the type is not handled. NOTE(review): Optional.of(model) throws NullPointerException when model is null, and name.equals(...) throws when name is null; confirm callers guarantee both. */ public SchemaElement getElement(SchemaElementType elementType, String name) { + Optional element = Optional.empty(); + + switch (elementType) { + case ENTITY: + element = Optional.ofNullable(entity); + break; + case MODEL: + element = Optional.of(model); + break; + case METRIC: + element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst(); + break; + case DIMENSION: + element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst(); + break; + case VALUE: + element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst(); + break; + default: + } + + if (element.isPresent()) { + return element.get(); + } else { + return null; + } + } + } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/RelateSchemaElement.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/RelateSchemaElement.java index 99a497555..052cf53b4 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/RelateSchemaElement.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/RelateSchemaElement.java @@ -1,9 +1,15 @@ package com.tencent.supersonic.chat.api.pojo; +import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; @Data +@Builder +@NoArgsConstructor +@AllArgsConstructor public class RelateSchemaElement { private Long dimensionId; diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java 
index 660f2c3da..62b3869b8 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java @@ -1,21 +1,11 @@ package com.tencent.supersonic.chat.api.pojo.response; -import cn.hutool.core.collection.CollectionUtil; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import lombok.Data; -import lombok.Getter; -import lombok.Builder; -import lombok.NoArgsConstructor; -import lombok.AllArgsConstructor; - import java.util.List; @Data -@Getter -@Builder -@NoArgsConstructor -@AllArgsConstructor public class ParseResp { private Integer chatId; private String queryText; @@ -23,8 +13,7 @@ public class ParseResp { private ParseState state; private List selectedParses = Lists.newArrayList(); private List candidateParses = Lists.newArrayList(); - private List similarSolvedQuery; - private ParseTimeCostDO parseTimeCost; +/* Created eagerly: the ParseTimeCostDO constructor stamps parseStartTime, so the start time is taken when this ParseResp is constructed. */ private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO(); public enum ParseState { COMPLETED, @@ -32,12 +21,4 @@ public class ParseResp { FAILED } - public List getSelectedParses() { - selectedParses = Lists.newArrayList(); - if (CollectionUtil.isNotEmpty(candidateParses)) { - selectedParses.addAll(candidateParses); - candidateParses.clear(); - } - return selectedParses; - } } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseTimeCostDO.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseTimeCostDO.java index df18d34db..7049c8e5c 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseTimeCostDO.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseTimeCostDO.java @@ -4,6 +4,12 @@ import lombok.Data; @Data public class ParseTimeCostDO { - private Long parseTime; - private Long sqlTime; + + private long parseStartTime; + private long parseTime; + private long 
sqlTime; + +/** Stamps parseStartTime with System.currentTimeMillis() at construction time. */ public ParseTimeCostDO() { + this.parseStartTime = System.currentTimeMillis(); + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/MetricCheckParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/MetricCheckParser.java deleted file mode 100644 index fb9ffbd66..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/MetricCheckParser.java +++ /dev/null @@ -1,100 +0,0 @@ -package com.tencent.supersonic.chat.parser.rule; - -import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.api.component.SemanticInterpreter; -import com.tencent.supersonic.chat.api.component.SemanticParser; -import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.ChatContext; -import com.tencent.supersonic.chat.api.pojo.ModelSchema; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery; -import com.tencent.supersonic.chat.utils.ComponentFactory; -import org.apache.commons.collections.CollectionUtils; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -public class MetricCheckParser implements SemanticParser { - - @Override - public void parse(QueryContext queryContext, ChatContext chatContext) { - List semanticQueries = queryContext.getCandidateQueries(); - if (CollectionUtils.isEmpty(semanticQueries)) { - return; - } - semanticQueries.removeIf(this::removeQuery); - } - - private boolean removeQuery(SemanticQuery semanticQuery) { - if (semanticQuery instanceof MetricSemanticQuery) { - 
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); - List schemaElementMatches = parseInfo.getElementMatches(); - List elementMatchFiltered = - filterMetricElement(schemaElementMatches, parseInfo.getModelId()); - return 0 >= getMetricElementMatchCount(elementMatchFiltered); - } - return false; - } - - private List filterMetricElement(List elementMatches, Long modelId) { - List filterSchemaElementMatch = Lists.newArrayList(); - SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); - ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true); - Set metricElements = modelSchema.getMetrics(); - Map valueElementMatchMap = getValueElementMap(elementMatches); - Map metricMap = metricElements.stream() - .collect(Collectors.toMap(SchemaElement::getId, e -> e, (e1, e2) -> e2)); - for (SchemaElementMatch schemaElementMatch : elementMatches) { - if (!SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType())) { - filterSchemaElementMatch.add(schemaElementMatch); - continue; - } - SchemaElement metric = metricMap.get(schemaElementMatch.getElement().getId()); - List necessaryDimensionIds = getNecessaryDimensionIds(metric); - boolean flag = true; - for (Long necessaryDimensionId : necessaryDimensionIds) { - if (!valueElementMatchMap.containsKey(necessaryDimensionId)) { - flag = false; - break; - } - } - if (flag) { - filterSchemaElementMatch.add(schemaElementMatch); - } - } - return filterSchemaElementMatch; - } - - private Map getValueElementMap(List elementMatches) { - return elementMatches.stream() - .filter(elementMatch -> - SchemaElementType.VALUE.equals(elementMatch.getElement().getType())) - .collect(Collectors.toMap(elementMatch -> elementMatch.getElement().getId(), e -> e, (e1, e2) -> e1)); - } - - private long getMetricElementMatchCount(List elementMatches) { - return elementMatches.stream().filter(elementMatch -> - SchemaElementType.METRIC.equals(elementMatch.getElement().getType())) - .count(); - } - - 
private List getNecessaryDimensionIds(SchemaElement metric) { - if (metric == null) { - return Lists.newArrayList(); - } - List relateSchemaElements = metric.getRelateSchemaElements(); - if (CollectionUtils.isEmpty(relateSchemaElements)) { - return Lists.newArrayList(); - } - return relateSchemaElements.stream() - .filter(RelateSchemaElement::isNecessary).map(RelateSchemaElement::getDimensionId) - .collect(Collectors.toList()); - } - -} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/RuleBasedParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/RuleBasedParser.java index 6df80ed7e..d2e82cbb1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/RuleBasedParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/RuleBasedParser.java @@ -19,7 +19,6 @@ public class RuleBasedParser implements SemanticParser { new QueryModeParser(), new ContextInheritParser(), new AgentCheckParser(), - new MetricCheckParser(), new TimeRangeParser(), new AggregateTypeParser() ); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/CostType.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/CostType.java index 637bffb7f..e00b41dcf 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/CostType.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/CostType.java @@ -4,7 +4,8 @@ public enum CostType { MAPPER(1, "mapper"), PARSER(2, "parser"), QUERY(3, "query"), - PARSERRESPONDER(4, "responder"); + PARSERRESPONDER(4, "responder"), + POSTPROCESSOR(5, "postprocessor"); private Integer type; private String name; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java 
index 5146901b3..5c9b8212d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -4,33 +4,31 @@ import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; +import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria; -import com.tencent.supersonic.chat.api.pojo.response.QueryResp; -import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper; import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper; import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper; import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.PageUtils; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.stream.Collectors; - import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; import org.springframework.context.annotation.Primary; import 
org.springframework.stereotype.Repository; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; @Repository @Primary @@ -136,7 +134,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { List candidateParses) { Long queryId = createChatParse(parseResult, chatCtx, queryReq); List chatParseDOList = new ArrayList<>(); - getChatParseDO(chatCtx, queryReq, queryId, 0, candidateParses, chatParseDOList); + getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList); chatParseMapper.batchSaveParseInfo(chatParseDOList); return chatParseDOList; } @@ -148,11 +146,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { } } - public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, + public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, List parses, List chatParseDOList) { for (int i = 0; i < parses.size(); i++) { ChatParseDO chatParseDO = new ChatParseDO(); - parses.get(i).setId(base + i + 1); chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId())); chatParseDO.setQuestionId(queryId); chatParseDO.setQueryText(queryReq.getQueryText()); @@ -161,7 +158,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { if (i == 0) { chatParseDO.setIsCandidate(0); } - chatParseDO.setParseId(base + i + 1); + chatParseDO.setParseId(parses.get(i).getId()); chatParseDO.setCreateTime(new java.util.Date()); chatParseDO.setUserName(queryReq.getUser().getName()); chatParseDOList.add(chatParseDO); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java new file mode 100644 index 000000000..5b07bc493 --- /dev/null +++ 
b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java @@ -0,0 +1,202 @@ +package com.tencent.supersonic.chat.postprocessor; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ModelSchema; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.common.pojo.QueryType; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * MetricCheckPostProcessor used to verify whether the dimensions + * involved in the query in metric mode can drill down on the metric. 
+ */ +public class MetricCheckPostProcessor implements PostProcessor { + +/** Rewrites the corrected S2SQL of every candidate whose query type is METRIC via processCorrectSql, then removes METRIC candidates whose non-blank corrected SQL has no aggregate fields. NOTE(review): getCandidateQueries() and getSqlInfo() are dereferenced without null checks; confirm neither can be null here. */ @Override + public void process(QueryContext queryContext) { + List semanticQueries = queryContext.getCandidateQueries(); + for (SemanticQuery semanticQuery : semanticQueries) { + SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); + if (!QueryType.METRIC.equals(parseInfo.getQueryType())) { + continue; + } + SchemaService schemaService = ContextUtils.getBean(SchemaService.class); + ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId()); + String correctSqlProcessed = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema); + parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed); + } + semanticQueries.removeIf(semanticQuery -> { + if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) { + return false; + } + String correctSql = semanticQuery.getParseInfo().getSqlInfo().getCorrectS2SQL(); + if (StringUtils.isBlank(correctSql)) { + return false; + } + return CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql)); + }); + } + +/** Returns correctSql with the offending fields removed: metrics that are missing from the model schema or lack one of their necessary dimensions, and where and group-by fields (other than names in TimeDimensionEnum.getNameList()) that are not schema dimensions or fail checkDrillDownDimension. Returns correctSql unchanged when it is blank or has no aggregate fields. A metric with no schema element is still passed to checkNecessaryDimension, which treats a null metric as having no necessary dimensions. NOTE(review): the three SqlParserSelectHelper calls run before the isBlank guard; confirm they tolerate null or blank SQL. */ public String processCorrectSql(String correctSql, ModelSchema modelSchema) { + List groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql); + List metricFields = SqlParserSelectHelper.getAggregateFields(correctSql); + List whereFields = SqlParserSelectHelper.getWhereFields(correctSql); + List dimensionFields = getDimensionFields(groupByFields, whereFields); + if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) { + return correctSql; + } + Set metricToRemove = Sets.newHashSet(); + Set groupByToRemove = Sets.newHashSet(); + Set whereFieldsToRemove = Sets.newHashSet(); + for (String metricName : metricFields) { + SchemaElement metricElement = modelSchema.getElement(SchemaElementType.METRIC, metricName); + if (metricElement == null) { + metricToRemove.add(metricName); + } + if (!checkNecessaryDimension(metricElement, modelSchema, 
dimensionFields)) { + metricToRemove.add(metricName); + } + } + for (String dimensionName : whereFields) { + if (TimeDimensionEnum.getNameList().contains(dimensionName)) { + continue; + } + if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, modelSchema)) { + whereFieldsToRemove.add(dimensionName); + } + if (!checkDrillDownDimension(dimensionName, metricFields, modelSchema)) { + whereFieldsToRemove.add(dimensionName); + } + } + for (String dimensionName : groupByFields) { + if (TimeDimensionEnum.getNameList().contains(dimensionName)) { + continue; + } + if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, modelSchema)) { + groupByToRemove.add(dimensionName); + } + if (!checkDrillDownDimension(dimensionName, metricFields, modelSchema)) { + groupByToRemove.add(dimensionName); + } + } + return removeFieldInSql(correctSql, metricToRemove, groupByToRemove, whereFieldsToRemove); + } + + + /** + * To check whether the dimension bound to the metric exists, + * eg: metric like UV is calculated in a certain dimension, it cannot be used on other dimensions. 
+ */ + private boolean checkNecessaryDimension(SchemaElement metric, ModelSchema modelSchema, + List dimensionFields) { + List necessaryDimensions = getNecessaryDimensionNames(metric, modelSchema); + if (CollectionUtils.isEmpty(necessaryDimensions)) { + return true; + } + for (String dimension : necessaryDimensions) { + if (!dimensionFields.contains(dimension)) { + return false; + } + } + return true; + } + + /** + * To check whether the dimension can drill down the metric, + * eg: some descriptive dimensions are not suitable as drill-down dimensions + */ + private boolean checkDrillDownDimension(String dimensionName, List metrics, + ModelSchema modelSchema) { + List metricElements = modelSchema.getMetrics().stream() + .filter(schemaElement -> metrics.contains(schemaElement.getName())) + .collect(Collectors.toList()); + if (CollectionUtils.isEmpty(metricElements)) { + return false; + } + List relateDimensions = metricElements.stream() + .filter(schemaElement -> !CollectionUtils.isEmpty(schemaElement.getRelateSchemaElements())) + .map(schemaElement -> schemaElement.getRelateSchemaElements().stream() + .map(RelateSchemaElement::getDimensionId).collect(Collectors.toList())) + .flatMap(Collection::stream) + .map(id -> convertDimensionIdToName(id, modelSchema)) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + //if no metric has drill down dimension, return true + if (CollectionUtils.isEmpty(relateDimensions)) { + return true; + } + //if this dimension not in relate drill-down dimensions, return false + return relateDimensions.contains(dimensionName); + } + +/** Maps the ids returned by getNecessaryDimensions to dimension names, dropping ids that resolve to no name. */ private List getNecessaryDimensionNames(SchemaElement metric, ModelSchema modelSchema) { + List necessaryDimensionIds = getNecessaryDimensions(metric); + return necessaryDimensionIds.stream().map(id -> convertDimensionIdToName(id, modelSchema)) + .filter(Objects::nonNull).collect(Collectors.toList()); + } + +/** Collects the dimension ids of the related schema elements for which isNecessary() is true; returns an empty list when metric is null or has no related elements. */ private List getNecessaryDimensions(SchemaElement metric) { + if (metric == null) { + return 
Lists.newArrayList(); + } + List relateSchemaElements = metric.getRelateSchemaElements(); + if (CollectionUtils.isEmpty(relateSchemaElements)) { + return Lists.newArrayList(); + } + return relateSchemaElements.stream() + .filter(RelateSchemaElement::isNecessary).map(RelateSchemaElement::getDimensionId) + .collect(Collectors.toList()); + } + +/** Returns a new list holding the group-by fields followed by the where fields, skipping whichever list CollectionUtils.isEmpty reports as empty. */ private List getDimensionFields(List groupByFields, List whereFields) { + List dimensionFields = Lists.newArrayList(); + if (!CollectionUtils.isEmpty(groupByFields)) { + dimensionFields.addAll(groupByFields); + } + if (!CollectionUtils.isEmpty(whereFields)) { + dimensionFields.addAll(whereFields); + } + return dimensionFields; + } + +/** Resolves a dimension id to its name through modelSchema.getElement; returns null when no element is found. */ private String convertDimensionIdToName(Long id, ModelSchema modelSchema) { + SchemaElement schemaElement = modelSchema.getElement(SchemaElementType.DIMENSION, id); + if (schemaElement == null) { + return null; + } + return schemaElement.getName(); + } + +/** Returns true when modelSchema.getElement(type, name) finds an element. */ private boolean checkInModelSchema(String name, SchemaElementType type, ModelSchema modelSchema) { + SchemaElement schemaElement = modelSchema.getElement(type, name); + return schemaElement != null; + } + +/** Applies the SqlParserRemoveHelper removals in this order: where conditions, metric selects, dimension selects, group-by entries, then removeNumberCondition; returns the resulting SQL. */ private static String removeFieldInSql(String sql, Set metricToRemove, + Set dimensionByToRemove, Set whereFieldsToRemove) { + sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove); + sql = SqlParserRemoveHelper.removeSelect(sql, metricToRemove); + sql = SqlParserRemoveHelper.removeSelect(sql, dimensionByToRemove); + sql = SqlParserRemoveHelper.removeGroupBy(sql, dimensionByToRemove); + sql = SqlParserRemoveHelper.removeNumberCondition(sql); + return sql; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java new file mode 100644 index 000000000..566a91d6d --- /dev/null +++ 
b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java @@ -0,0 +1,29 @@ +package com.tencent.supersonic.chat.postprocessor; + +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.service.ParseInfoService; +import com.tencent.supersonic.common.util.ContextUtils; +import org.springframework.util.CollectionUtils; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Updates the parse info of every candidate query from its corrected SQL. + */ +public class ParseInfoUpdateProcessor implements PostProcessor { + +/** Does nothing when there are no candidate queries; otherwise passes the parse info of each candidate to ParseInfoService.updateParseInfo. */ @Override + public void process(QueryContext queryContext) { + List candidateQueries = queryContext.getCandidateQueries(); + if (CollectionUtils.isEmpty(candidateQueries)) { + return; + } + ParseInfoService parseInfoService = ContextUtils.getBean(ParseInfoService.class); + List candidateParses = candidateQueries.stream() + .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); + candidateParses.forEach(parseInfoService::updateParseInfo); + } + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/PostProcessor.java b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/PostProcessor.java new file mode 100644 index 000000000..50536af09 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/PostProcessor.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.chat.postprocessor; +import com.tencent.supersonic.chat.api.pojo.QueryContext; + +/** + * A post processor runs additional logic after the parser and the corrector. + */ + +public interface PostProcessor { + +/** Applies this processor to the given query context. */ void process(QueryContext queryContext); + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java index 
c4687d31d..f7d51546d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java @@ -4,13 +4,13 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; -import java.util.List; -import java.util.ArrayList; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; @Slf4j @Component @@ -30,6 +30,7 @@ public class QueryRanker { } else { selectedQueries = getTopCandidateQuery(candidateQueries); } + generateParseInfoId(selectedQueries); log.debug("pick after [{}]", selectedQueries); return selectedQueries; } @@ -48,6 +49,13 @@ public class QueryRanker { .collect(Collectors.toList()); } +/** Assigns sequential ids starting at 1 to the parse infos of the given queries, following list order. */ private void generateParseInfoId(List semanticQueries) { + for (int i = 0; i < semanticQueries.size(); i++) { + SemanticQuery query = semanticQueries.get(i); + query.getParseInfo().setId(i + 1); + } + } + private boolean checkFullyInherited(SemanticQuery query) { SemanticParseInfo parseInfo = query.getParseInfo(); if (!(query instanceof RuleSemanticQuery)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index b45318b80..103760ad6 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -1,27 +1,30 @@ 
package com.tencent.supersonic.chat.responder.parse; +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; -import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; -import java.util.List; import org.springframework.util.CollectionUtils; +import java.util.List; +import java.util.stream.Collectors; public class EntityInfoParseResponder implements ParseResponder { @Override - public void fillResponse(ParseResp parseResp, QueryContext queryContext, - List chatParseDOS) { - List selectedParses = parseResp.getCandidateParses(); - if (CollectionUtils.isEmpty(selectedParses)) { + public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { + List semanticQueries = queryContext.getCandidateQueries(); + if (CollectionUtils.isEmpty(semanticQueries)) { return; } + List selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo) + .collect(Collectors.toList()); QueryReq queryReq = queryContext.getRequest(); selectedParses.forEach(parseInfo -> { String queryMode = parseInfo.getQueryMode(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseRespBuildParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseRespBuildParseResponder.java new file mode 100644 index 000000000..526d17b56 --- /dev/null +++ 
b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseRespBuildParseResponder.java @@ -0,0 +1,37 @@ +package com.tencent.supersonic.chat.responder.parse; + +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ChatContext; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.ParseResp; +import com.tencent.supersonic.chat.service.ChatService; +import com.tencent.supersonic.common.util.ContextUtils; +import lombok.extern.slf4j.Slf4j; +import java.util.List; +import java.util.stream.Collectors; + +@Slf4j +public class ParseRespBuildParseResponder implements ParseResponder { + + @Override + public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { + QueryReq queryReq = queryContext.getRequest(); + parseResp.setChatId(queryReq.getChatId()); + parseResp.setQueryText(queryReq.getQueryText()); + List candidateQueries = queryContext.getCandidateQueries(); + if (candidateQueries.size() > 0) { + List candidateParses = candidateQueries.stream() + .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); + parseResp.setCandidateParses(candidateParses); + parseResp.setState(ParseResp.ParseState.COMPLETED); + parseResp.setCandidateParses(candidateParses); + ChatService chatService = ContextUtils.getBean(ChatService.class); + chatService.batchAddParse(chatContext, queryReq, parseResp); + } else { + parseResp.setState(ParseResp.ParseState.FAILED); + } + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseResponder.java index 2953c236c..9ac17a2ec 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseResponder.java 
+++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseResponder.java @@ -1,12 +1,11 @@ package com.tencent.supersonic.chat.responder.parse; +import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; -import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; -import java.util.List; public interface ParseResponder { - void fillResponse(ParseResp parseResp, QueryContext queryContext, List chatParseDOS); + void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext); } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseTimeParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseTimeParseResponder.java new file mode 100644 index 000000000..cc09a096f --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ParseTimeParseResponder.java @@ -0,0 +1,20 @@ +package com.tencent.supersonic.chat.responder.parse; + +import com.tencent.supersonic.chat.api.pojo.ChatContext; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.response.ParseResp; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class ParseTimeParseResponder implements ParseResponder { + + + @Override + public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { + long parseStartTime = parseResp.getParseTimeCost().getParseStartTime(); + parseResp.getParseTimeCost().setParseTime( + System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime()); + } + + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/QueryRankParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/QueryRankParseResponder.java new file mode 100644 
index 000000000..53d52bbb8 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/QueryRankParseResponder.java @@ -0,0 +1,24 @@ +package com.tencent.supersonic.chat.responder.parse; + +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ChatContext; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.response.ParseResp; +import com.tencent.supersonic.chat.query.QueryRanker; +import com.tencent.supersonic.common.util.ContextUtils; + +import java.util.List; + +/** + * Rank queries by score. + */ +public class QueryRankParseResponder implements ParseResponder { + + @Override + public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { + List candidateQueries = queryContext.getCandidateQueries(); + QueryRanker queryRanker = ContextUtils.getBean(QueryRanker.class); + candidateQueries = queryRanker.rank(candidateQueries); + queryContext.setCandidateQueries(candidateQueries); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java index 676e224ab..55437f92b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java @@ -1,50 +1,32 @@ package com.tencent.supersonic.chat.responder.parse; import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; -import 
com.tencent.supersonic.chat.api.pojo.response.ParseTimeCostDO; -import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.query.QueryManager; -import com.tencent.supersonic.common.util.JsonUtil; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; -import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; public class SqlInfoParseResponder implements ParseResponder { @Override - public void fillResponse(ParseResp parseResp, QueryContext queryContext, - List chatParseDOS) { + public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { QueryReq queryReq = queryContext.getRequest(); - Long startTime = System.currentTimeMillis(); - addSqlInfo(queryReq, parseResp.getCandidateParses()); - parseResp.setParseTimeCost(new ParseTimeCostDO()); - parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime); - if (!CollectionUtils.isEmpty(chatParseDOS)) { - Map chatParseDOMap = chatParseDOS.stream() - .collect(Collectors.toMap(ChatParseDO::getParseId, - Function.identity(), (oldValue, newValue) -> newValue)); - updateParseInfo(chatParseDOMap, parseResp.getCandidateParses()); - } - } - - private void updateParseInfo(Map chatParseDOMap, List parseInfos) { - if (CollectionUtils.isEmpty(parseInfos)) { + List semanticQueries = queryContext.getCandidateQueries(); + if (CollectionUtils.isEmpty(semanticQueries)) { return; } - for (SemanticParseInfo parseInfo : parseInfos) { - ChatParseDO chatParseDO = chatParseDOMap.get(parseInfo.getId()); - if (chatParseDO != null) { - chatParseDO.setParseInfo(JsonUtil.toString(parseInfo)); - } - } + List selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo) + .collect(Collectors.toList()); + long startTime = 
System.currentTimeMillis(); + addSqlInfo(queryReq, selectedParses); + parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime); } private void addSqlInfo(QueryReq queryReq, List semanticParseInfos) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 99a9b7ab0..eadbd15b1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -26,14 +26,13 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; +import com.tencent.supersonic.chat.postprocessor.PostProcessor; import com.tencent.supersonic.chat.query.QueryManager; -import com.tencent.supersonic.chat.query.QueryRanker; import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; import com.tencent.supersonic.chat.service.ChatService; -import com.tencent.supersonic.chat.service.ParseInfoService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; @@ -61,15 +60,6 @@ import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.NatureHelper; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import java.util.ArrayList; 
-import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.PriorityQueue; -import java.util.Set; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.LongValue; @@ -85,7 +75,6 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.schema.Column; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.compress.utils.Lists; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -93,6 +82,16 @@ import org.springframework.context.annotation.Primary; import org.springframework.stereotype.Component; import org.springframework.stereotype.Service; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.stream.Collectors; + @Service @Component("chatQueryService") @Primary @@ -105,23 +104,20 @@ public class QueryServiceImpl implements QueryService { private StatisticsService statisticsService; @Autowired private SolvedQueryManager solvedQueryManager; - @Autowired - private ParseInfoService parseInfoService; - @Autowired - private QueryRanker queryRanker; @Value("${time.threshold: 100}") private Integer timeThreshold; private List schemaMappers = ComponentFactory.getSchemaMappers(); private List semanticParsers = ComponentFactory.getSemanticParsers(); + private List postProcessors = ComponentFactory.getPostProcessors(); private List parseResponders = ComponentFactory.getParseResponders(); private List executeResponders = ComponentFactory.getExecuteResponders(); private List 
semanticCorrectors = ComponentFactory.getSqlCorrections(); @Override public ParseResp performParsing(QueryReq queryReq) { - Long parseTime = System.currentTimeMillis(); + ParseResp parseResult = new ParseResp(); //1. build queryContext and chatContext QueryContext queryCtx = new QueryContext(queryReq); // in order to support multi-turn conversation, chat context is needed @@ -129,16 +125,16 @@ public class QueryServiceImpl implements QueryService { List timeCostDOList = new ArrayList<>(); //2. mapper - schemaMappers.stream().forEach(mapper -> { - Long startTime = System.currentTimeMillis(); + schemaMappers.forEach(mapper -> { + long startTime = System.currentTimeMillis(); mapper.map(queryCtx); timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build()); }); //3. parser - semanticParsers.stream().forEach(parser -> { - Long startTime = System.currentTimeMillis(); + semanticParsers.forEach(parser -> { + long startTime = System.currentTimeMillis(); parser.parse(queryCtx, chatCtx); timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build()); @@ -153,46 +149,34 @@ public class QueryServiceImpl implements QueryService { if (semanticQuery instanceof RuleSemanticQuery) { continue; } - semanticCorrectors.stream().forEach(correction -> { + semanticCorrectors.forEach(correction -> { correction.correct(queryReq, semanticQuery.getParseInfo()); }); } } - //5. generate parsing results. 
- ParseResp parseResult; - List chatParseDOS = Lists.newArrayList(); - if (candidateQueries.size() > 0) { - candidateQueries = queryRanker.rank(candidateQueries); - List candidateParses = candidateQueries.stream() - .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); - candidateParses.forEach(parseInfo -> parseInfoService.updateParseInfo(parseInfo)); - parseResult = ParseResp.builder().chatId(queryReq.getChatId()).queryText(queryReq.getQueryText()) - .state(ParseResp.ParseState.COMPLETED).candidateParses(candidateParses) - .build(); - chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult); - } else { - parseResult = ParseResp.builder() - .chatId(queryReq.getChatId()) - .queryText(queryReq.getQueryText()) - .state(ParseResp.ParseState.FAILED) - .build(); - } + //5. postProcessor + postProcessors.forEach(postProcessor -> { + long startTime = System.currentTimeMillis(); + postProcessor.process(queryCtx); + timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) + .interfaceName(postProcessor.getClass().getSimpleName()) + .type(CostType.POSTPROCESSOR.getType()).build()); + }); + //6. 
responders - for (ParseResponder parseResponder : parseResponders) { - Long startTime = System.currentTimeMillis(); - parseResponder.fillResponse(parseResult, queryCtx, chatParseDOS); + parseResponders.forEach(parseResponder -> { + long startTime = System.currentTimeMillis(); + parseResponder.fillResponse(parseResult, queryCtx, chatCtx); timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(parseResponder.getClass().getSimpleName()) .type(CostType.PARSERRESPONDER.getType()).build()); - } + }); + if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) { - saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), + saveTimeCostInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); } - chatService.updateChatParse(chatParseDOS); - parseResult.getParseTimeCost().setParseTime( - System.currentTimeMillis() - parseTime - parseResult.getParseTimeCost().getSqlTime()); return parseResult; } @@ -220,7 +204,7 @@ public class QueryServiceImpl implements QueryService { timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build()); queryResult.setQueryTimeCost(timeCostDOList.get(0).getCost().longValue()); - saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), + saveTimeCostInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); queryResult.setChatContext(parseInfo); // update chat context after a successful semantic query @@ -242,7 +226,7 @@ public class QueryServiceImpl implements QueryService { } // save time cost data - private void saveInfo(List timeCostDOList, + private void saveTimeCostInfo(List timeCostDOList, String queryText, Long queryId, String userName, Long chatId) { List 
list = timeCostDOList.stream() diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java index e4538827a..296932766 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java @@ -1,19 +1,18 @@ package com.tencent.supersonic.chat.utils; import com.tencent.supersonic.chat.api.component.SchemaMapper; +import com.tencent.supersonic.chat.api.component.SemanticCorrector; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.component.SemanticParser; - -import com.tencent.supersonic.chat.api.component.SemanticCorrector; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver; +import com.tencent.supersonic.chat.postprocessor.PostProcessor; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; import org.apache.commons.collections.CollectionUtils; import org.springframework.core.io.support.SpringFactoriesLoader; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; public class ComponentFactory { @@ -21,6 +20,7 @@ public class ComponentFactory { private static List semanticParsers = new ArrayList<>(); private static List s2SQLCorrections = new ArrayList<>(); private static SemanticInterpreter semanticInterpreter; + private static List postProcessors = new ArrayList<>(); private static List parseResponders = new ArrayList<>(); private static List executeResponders = new ArrayList<>(); private static ModelResolver modelResolver; @@ -37,6 +37,10 @@ public class ComponentFactory { s2SQLCorrections) : s2SQLCorrections; } + public static List getPostProcessors() { + return 
CollectionUtils.isEmpty(postProcessors) ? init(PostProcessor.class, postProcessors) : postProcessors; + } + public static List getParseResponders() { return CollectionUtils.isEmpty(parseResponders) ? init(ParseResponder.class, parseResponders) : parseResponders; } diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java new file mode 100644 index 000000000..35c979858 --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java @@ -0,0 +1,124 @@ +package com.tencent.supersonic.chat.postprocessor; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.tencent.supersonic.chat.api.pojo.ModelSchema; +import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import java.util.List; +import java.util.Set; + +class MetricCheckPostProcessorTest { + + @Test + void testProcessCorrectSql_necessaryDimension_groupBy() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema()); + String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_necessaryDimension_where() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名"; + String 
actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema()); + String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 " + + "WHERE 部门 = 'HR' GROUP BY 用户名"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_dimensionNotDrillDown_groupBy() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema()); + String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_dimensionNotDrillDown_where() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema()); + String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema()); + String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_dimensionDrillDown() { + MetricCheckPostProcessor metricCheckPostProcessor = new 
MetricCheckPostProcessor(); + String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema()); + String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_noDrillDownDimensionSetting() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, + mockModelSchemaNoDimensionSetting()); + String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + /** + * 访问次数 drill down dimension is 用户名 and 部门 + * 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions + */ + private ModelSchema mockModelSchema() { + ModelSchema modelSchema = new ModelSchema(); + Set metrics = Sets.newHashSet( + mockElement(1L, "访问次数", SchemaElementType.METRIC, + Lists.newArrayList(RelateSchemaElement.builder().dimensionId(2L).isNecessary(false).build(), + RelateSchemaElement.builder().dimensionId(1L).isNecessary(false).build())), + mockElement(2L, "访问用户数", SchemaElementType.METRIC, + Lists.newArrayList(RelateSchemaElement.builder().dimensionId(2L).isNecessary(true).build())) + ); + modelSchema.setMetrics(metrics); + modelSchema.setDimensions(mockDimensions()); + return modelSchema; + } + + private ModelSchema mockModelSchemaNoDimensionSetting() { + ModelSchema modelSchema = new ModelSchema(); + Set metrics = Sets.newHashSet( + mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()), + mockElement(2L, 
"访问用户数", SchemaElementType.METRIC, Lists.newArrayList()) + ); + modelSchema.setMetrics(metrics); + modelSchema.setDimensions(mockDimensions()); + return modelSchema; + } + + private Set mockDimensions() { + return Sets.newHashSet( + mockElement(1L, "用户名", SchemaElementType.DIMENSION, Lists.newArrayList()), + mockElement(2L, "部门", SchemaElementType.DIMENSION, Lists.newArrayList()), + mockElement(3L, "页面", SchemaElementType.DIMENSION, Lists.newArrayList()) + ); + } + + private SchemaElement mockElement(Long id, String name, SchemaElementType type, + List relateSchemaElements) { + return SchemaElement.builder().id(id).name(name).type(type) + .relateSchemaElements(relateSchemaElements).build(); + } + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java index 5de1ad2b9..debb7a9b8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java @@ -1,30 +1,37 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.List; -import java.util.Objects; -import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; -import 
net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.expression.operators.relational.GreaterThan; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.InExpression; +import net.sf.jsqlparser.expression.operators.relational.LikeExpression; import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; -import net.sf.jsqlparser.expression.operators.relational.LikeExpression; +import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; +import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; +import org.springframework.util.CollectionUtils; + +import java.util.List; +import java.util.Objects; +import java.util.Set; /** * Sql Parser remove Helper @@ -56,6 +63,9 @@ public class SqlParserRemoveHelper { } public static String removeNumberCondition(String sql) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); + if (selectStatement == null) { + return sql; + } SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectBody instanceof PlainSelect)) { @@ -191,6 +201,55 @@ public class SqlParserRemoveHelper { return selectStatement.toString(); } + + public static String removeSelect(String sql, Set fields) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + if (selectStatement == null) { + return sql; + } + SelectBody selectBody = 
selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List selectItems = ((PlainSelect) selectBody).getSelectItems(); + selectItems.removeIf(selectItem -> { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; + String columnName = SqlParserSelectHelper.getColumnName(selectExpressionItem.getExpression()); + return fields.contains(columnName); + } + return false; + }); + return selectStatement.toString(); + } + + public static String removeGroupBy(String sql, Set fields) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + if (selectStatement == null) { + return sql; + } + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + GroupByElement groupByElement = ((PlainSelect) selectBody).getGroupBy(); + if (groupByElement == null) { + return sql; + } + ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList(); + groupByExpressionList.getExpressions().removeIf(expression -> { + if (expression instanceof Column) { + Column column = (Column) expression; + return fields.contains(column.getColumnName()); + } + return false; + }); + if (CollectionUtils.isEmpty(groupByExpressionList.getExpressions())) { + ((PlainSelect) selectBody).setGroupByElement(null); + } + return selectStatement.toString(); + } + private static Expression filteredWhereExpression(Expression where) { if (Objects.isNull(where)) { return null; diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 66e211d18..98a40f3b6 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -31,9 +31,16 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor 
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor +com.tencent.supersonic.chat.postprocessor.PostProcessor=\ + com.tencent.supersonic.chat.postprocessor.MetricCheckPostProcessor, \ + com.tencent.supersonic.chat.postprocessor.ParseInfoUpdateProcessor + com.tencent.supersonic.chat.responder.parse.ParseResponder=\ + com.tencent.supersonic.chat.responder.parse.QueryRankParseResponder, \ com.tencent.supersonic.chat.responder.parse.EntityInfoParseResponder, \ - com.tencent.supersonic.chat.responder.parse.SqlInfoParseResponder + com.tencent.supersonic.chat.responder.parse.SqlInfoParseResponder, \ + com.tencent.supersonic.chat.responder.parse.ParseTimeParseResponder, \ + com.tencent.supersonic.chat.responder.parse.ParseRespBuildParseResponder com.tencent.supersonic.chat.responder.execute.ExecuteResponder=\ com.tencent.supersonic.chat.responder.execute.EntityInfoExecuteResponder, \ diff --git a/launchers/standalone/src/test/resources/META-INF/spring.factories b/launchers/standalone/src/test/resources/META-INF/spring.factories index 12d09fbb4..98a40f3b6 100644 --- a/launchers/standalone/src/test/resources/META-INF/spring.factories +++ b/launchers/standalone/src/test/resources/META-INF/spring.factories @@ -1,19 +1,47 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\ - com.tencent.supersonic.chat.mapper.HanlpDictMapper + com.tencent.supersonic.chat.mapper.EmbeddingMapper, \ + com.tencent.supersonic.chat.mapper.HanlpDictMapper, \ + com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \ + com.tencent.supersonic.chat.mapper.QueryFilterMapper, \ + com.tencent.supersonic.chat.mapper.EntityMapper com.tencent.supersonic.chat.api.component.SemanticParser=\ com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \ - com.tencent.supersonic.chat.parser.llm.interpret.MetricInterpretParser -# com.tencent.supersonic.chat.parser.llm.DSLQueryFunction + 
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \ + com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \ + com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \ + com.tencent.supersonic.chat.parser.QueryTypeParser -com.tencent.supersonic.chat.api.component.QueryProcessor=\ - com.tencent.supersonic.chat.application.processor.SemanticQueryProcessor +com.tencent.supersonic.chat.api.component.SemanticCorrector=\ + com.tencent.supersonic.chat.corrector.SchemaCorrector, \ + com.tencent.supersonic.chat.corrector.SelectCorrector, \ + com.tencent.supersonic.chat.corrector.WhereCorrector, \ + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.HavingCorrector com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter -com.tencent.supersonic.chat.application.query.DomainResolver=\ - com.tencent.supersonic.chat.application.query.HeuristicDomainResolver +com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\ + com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor=\ com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor + +com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ + com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor + +com.tencent.supersonic.chat.postprocessor.PostProcessor=\ + com.tencent.supersonic.chat.postprocessor.MetricCheckPostProcessor, \ + com.tencent.supersonic.chat.postprocessor.ParseInfoUpdateProcessor + +com.tencent.supersonic.chat.responder.parse.ParseResponder=\ + com.tencent.supersonic.chat.responder.parse.QueryRankParseResponder, \ + com.tencent.supersonic.chat.responder.parse.EntityInfoParseResponder, \ + com.tencent.supersonic.chat.responder.parse.SqlInfoParseResponder, \ + 
com.tencent.supersonic.chat.responder.parse.ParseTimeParseResponder, \ + com.tencent.supersonic.chat.responder.parse.ParseRespBuildParseResponder + +com.tencent.supersonic.chat.responder.execute.ExecuteResponder=\ + com.tencent.supersonic.chat.responder.execute.EntityInfoExecuteResponder, \ + com.tencent.supersonic.chat.responder.execute.SimilarMetricExecuteResponder \ No newline at end of file