mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
[improvement][headless&chat]Move EntityInfoProcessor from chat to headless module and optimize code.
[improvement][headless&chat]Move `EntityInfoProcessor` from `chat` to `headless` module and optimize code.
This commit is contained in:
@@ -478,8 +478,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
SemanticQueryResp queryResultWithColumns =
|
||||
getQueryResultWithSchemaResp(entityInfo, dataSetSchema, user);
|
||||
if (queryResultWithColumns != null) {
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
|
||||
&& queryResultWithColumns.getResultList().size() > 0) {
|
||||
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())) {
|
||||
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
|
||||
for (Map.Entry<String, Object> entry : result.entrySet()) {
|
||||
String entryKey = getEntryKey(entry);
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.tencent.supersonic.headless.server.processor;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
|
||||
/**
|
||||
* EntityInfoProcessor fills core attributes of an entity so that
|
||||
* users get to know which entity is parsed out.
|
||||
*/
|
||||
public class EntityInfoProcessor implements ResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
|
||||
parseResp.getSelectedParses().forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (!QueryManager.isTagQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) {
|
||||
return;
|
||||
}
|
||||
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatQueryContext.getUser());
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -8,21 +8,19 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -40,94 +38,79 @@ public class ParseInfoProcessor implements ResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
if (CollectionUtils.isEmpty(candidateQueries)) {
|
||||
return;
|
||||
}
|
||||
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
candidateParses.forEach(this::updateParseInfo);
|
||||
parseResp.getSelectedParses().forEach(this::updateParseInfo);
|
||||
}
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(correctS2SQL)) {
|
||||
String s2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
|
||||
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(s2SQL);
|
||||
|
||||
//set dataInfo
|
||||
//extract date filter from S2SQL
|
||||
try {
|
||||
if (!org.apache.commons.collections.CollectionUtils.isEmpty(expressions)) {
|
||||
DateConf dateInfo = getDateInfo(expressions);
|
||||
if (dateInfo != null && parseInfo.getDateInfo() == null) {
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
if (parseInfo.getDateInfo() == null && !CollectionUtils.isEmpty(expressions)) {
|
||||
parseInfo.setDateInfo(extractDateFilter(expressions));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("set dateInfo error :", e);
|
||||
log.error("failed to extract date range:", e);
|
||||
}
|
||||
|
||||
if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) {
|
||||
return;
|
||||
}
|
||||
//set filter
|
||||
//extract dimension filters from S2SQL
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticLayerService semanticLayerService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dsSchema = semanticLayerService.getDataSetSchema(dataSetId);
|
||||
|
||||
try {
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(dataSetId);
|
||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(result);
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(dsSchema);
|
||||
parseInfo.getDimensionFilters().addAll(extractDimensionFilter(fieldNameToElement, expressions));
|
||||
} catch (Exception e) {
|
||||
log.error("set dimensionFilter error :", e);
|
||||
log.error("failed to extract dimension filters:", e);
|
||||
}
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL()));
|
||||
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
|
||||
//extract metrics from S2SQL
|
||||
List<String> allFields = filterDateField(SqlSelectHelper.getAllSelectFields(s2SQL));
|
||||
Set<SchemaElement> metrics = matchSchemaElements(allFields, dsSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
//extract dimensions from S2SQL
|
||||
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(s2SQL);
|
||||
List<String> groupByDimensions = filterDateField(groupByFields);
|
||||
parseInfo.setDimensions(matchSchemaElements(groupByDimensions, dsSchema.getDimensions()));
|
||||
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(s2SQL);
|
||||
List<String> selectDimensions = filterDateField(selectFields);
|
||||
parseInfo.setDimensions(matchSchemaElements(selectDimensions, dsSchema.getDimensions()));
|
||||
}
|
||||
}
|
||||
|
||||
private Set<SchemaElement> getElements(Long dataSetId, List<String> allFields, List<SchemaElement> elements) {
|
||||
private Set<SchemaElement> matchSchemaElements(List<String> allFields, Set<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> {
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dataSetId.equals(schemaElement.getDataSet()) && allFields.contains(
|
||||
schemaElement.getName());
|
||||
return allFields.contains(schemaElement.getName());
|
||||
}
|
||||
Set<String> allFieldsSet = new HashSet<>(allFields);
|
||||
Set<String> aliasSet = new HashSet<>(schemaElement.getAlias());
|
||||
List<String> intersection = allFieldsSet.stream()
|
||||
.filter(aliasSet::contains).collect(Collectors.toList());
|
||||
return dataSetId.equals(schemaElement.getDataSet()) && (allFields.contains(
|
||||
schemaElement.getName()) || !CollectionUtils.isEmpty(intersection));
|
||||
return allFields.contains(schemaElement.getName())
|
||||
|| !CollectionUtils.isEmpty(intersection);
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private List<String> getFieldsExceptDate(List<String> allFields) {
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
private List<String> filterDateField(List<String> allFields) {
|
||||
return allFields.stream()
|
||||
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FieldExpression> fieldExpressions) {
|
||||
private List<QueryFilter> extractDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FieldExpression> fieldExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FieldExpression expression : fieldExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
@@ -148,7 +131,7 @@ public class ParseInfoProcessor implements ResultProcessor {
|
||||
return result;
|
||||
}
|
||||
|
||||
private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
|
||||
private DateConf extractDateFilter(List<FieldExpression> fieldExpressions) {
|
||||
List<FieldExpression> dateExpressions = fieldExpressions.stream()
|
||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
@@ -193,10 +176,9 @@ public class ParseInfoProcessor implements ResultProcessor {
|
||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
protected Map<String, SchemaElement> getNameToElement(Long dataSetId) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
protected Map<String, SchemaElement> getNameToElement(DataSetSchema dsSchema) {
|
||||
Set<SchemaElement> dimensions = dsSchema.getDimensions();
|
||||
Set<SchemaElement> metrics = dsSchema.getMetrics();
|
||||
|
||||
List<SchemaElement> allElements = Lists.newArrayList();
|
||||
allElements.addAll(dimensions);
|
||||
@@ -214,7 +196,7 @@ public class ParseInfoProcessor implements ResultProcessor {
|
||||
}
|
||||
return result.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(),
|
||||
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight,
|
||||
(value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
@@ -33,10 +33,10 @@ import java.util.stream.Collectors;
|
||||
public class ChatWorkflowEngine {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||
private final List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||
private final List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
||||
private final List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||
|
||||
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
||||
@@ -44,7 +44,7 @@ public class ChatWorkflowEngine {
|
||||
switch (queryCtx.getChatWorkflowState()) {
|
||||
case MAPPING:
|
||||
performMapping(queryCtx);
|
||||
if (queryCtx.getMapInfo().getMatchedDataSetInfos().size() == 0) {
|
||||
if (queryCtx.getMapInfo().getMatchedDataSetInfos().isEmpty()) {
|
||||
parseResult.setState(ParseResp.ParseState.FAILED);
|
||||
parseResult.setErrorMsg("No semantic entities can be mapped against user question.");
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
@@ -54,7 +54,7 @@ public class ChatWorkflowEngine {
|
||||
break;
|
||||
case PARSING:
|
||||
performParsing(queryCtx);
|
||||
if (queryCtx.getCandidateQueries().size() == 0) {
|
||||
if (queryCtx.getCandidateQueries().isEmpty()) {
|
||||
parseResult.setState(ParseResp.ParseState.FAILED);
|
||||
parseResult.setErrorMsg("No semantic queries can be parsed out.");
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
|
||||
@@ -78,7 +78,7 @@ public class QueryReqConverter {
|
||||
querySQLReq.setSql(SqlReplaceHelper.replaceAggAliasOrderItem(querySQLReq.getSql()));
|
||||
log.debug("replaceOrderAggSameAlias {} -> {}", reqSql, querySQLReq.getSql());
|
||||
//4.build MetricTables
|
||||
List<String> allFields = SqlSelectHelper.getAllFields(querySQLReq.getSql());
|
||||
List<String> allFields = SqlSelectHelper.getAllSelectFields(querySQLReq.getSql());
|
||||
List<MetricSchemaResp> metricSchemas = getMetrics(semanticSchemaResp, allFields);
|
||||
List<String> metrics = metricSchemas.stream().map(m -> m.getBizName()).collect(Collectors.toList());
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
|
||||
@@ -124,7 +124,7 @@ public class QueryStructUtils {
|
||||
}
|
||||
|
||||
public Set<String> getResName(QuerySqlReq querySqlReq) {
|
||||
return new HashSet<>(SqlSelectHelper.getAllFields(querySqlReq.getSql()));
|
||||
return new HashSet<>(SqlSelectHelper.getAllSelectFields(querySqlReq.getSql()));
|
||||
}
|
||||
|
||||
public Set<String> getBizNameFromSql(QuerySqlReq querySqlReq,
|
||||
|
||||
@@ -141,7 +141,7 @@ public class StatUtils {
|
||||
public void initSqlStatInfo(QuerySqlReq querySqlReq, User facadeUser) {
|
||||
QueryStat queryStatInfo = new QueryStat();
|
||||
List<String> aggFields = SqlSelectHelper.getAggregateFields(querySqlReq.getSql());
|
||||
List<String> allFields = SqlSelectHelper.getAllFields(querySqlReq.getSql());
|
||||
List<String> allFields = SqlSelectHelper.getAllSelectFields(querySqlReq.getSql());
|
||||
List<String> dimensions = allFields.stream().filter(aggFields::contains).collect(Collectors.toList());
|
||||
|
||||
String userName = getUserName(facadeUser);
|
||||
|
||||
Reference in New Issue
Block a user