mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
[improvement][headless]Deprecate and remove entity-related abstraction and logic.#1876
This commit is contained in:
@@ -1,13 +1,9 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -18,9 +14,7 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
||||
@Slf4j
|
||||
@@ -42,13 +36,11 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
&& aggregateFields.size() == selectFields.size()) {
|
||||
return;
|
||||
}
|
||||
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
correctS2SQL = addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
protected String addFieldsToSelect(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
protected String addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
@@ -70,34 +62,4 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
return addFieldsToSelectSql;
|
||||
}
|
||||
|
||||
private String addTagDefaultFields(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
// If it is in DETAIL mode and select *, add default metrics and dimensions.
|
||||
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
|
||||
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
|
||||
return correctS2SQL;
|
||||
}
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
Set<String> needAddDefaultFields = new HashSet<>();
|
||||
if (Objects.nonNull(dataSetSchema)) {
|
||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
|
||||
Set<String> metrics = dataSetSchema.getTagDefaultMetrics().stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
needAddDefaultFields.addAll(metrics);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
|
||||
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions().stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
needAddDefaultFields.addAll(dimensions);
|
||||
}
|
||||
}
|
||||
// remove * in sql and add default fields.
|
||||
if (!CollectionUtils.isEmpty(needAddDefaultFields)) {
|
||||
correctS2SQL =
|
||||
SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
|
||||
}
|
||||
return correctS2SQL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,9 +35,6 @@ public class NatureHelper {
|
||||
case DIMENSION:
|
||||
result = SchemaElementType.DIMENSION;
|
||||
break;
|
||||
case ENTITY:
|
||||
result = SchemaElementType.ENTITY;
|
||||
break;
|
||||
case DATASET:
|
||||
result = SchemaElementType.DATASET;
|
||||
break;
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** A mapper capable of converting the VALUE of entity dimension values into ID types. */
|
||||
@Slf4j
|
||||
public class EntityMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList =
|
||||
schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(dataSetId, chatQueryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch -> SchemaElementType.VALUE
|
||||
.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
continue;
|
||||
}
|
||||
if (!checkExistSameEntitySchemaElements(schemaElementMatch,
|
||||
schemaElementMatchList)) {
|
||||
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
|
||||
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
|
||||
entitySchemaElementMath.setElement(entity);
|
||||
schemaElementMatchList.add(entitySchemaElementMath);
|
||||
}
|
||||
schemaElementMatch.getElement().setType(SchemaElementType.ID);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
||||
List<SchemaElementMatch> schemaElementMatchList) {
|
||||
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch -> SchemaElementType.ENTITY
|
||||
.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
|
||||
if (schemaElementMatch.getElement().getId()
|
||||
.equals(valueSchemaElementMatch.getElement().getId())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long dataSetId, ChatQueryContext chatQueryContext) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -94,9 +94,8 @@ public class MapFilter {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType type = element.getType();
|
||||
|
||||
boolean isEntityOrDatasetOrId = SchemaElementType.ENTITY.equals(type)
|
||||
|| SchemaElementType.DATASET.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
boolean isEntityOrDatasetOrId =
|
||||
SchemaElementType.DATASET.equals(type) || SchemaElementType.ID.equals(type);
|
||||
|
||||
return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
|
||||
});
|
||||
|
||||
@@ -1,29 +1,20 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL or ID. */
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@@ -52,22 +43,6 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
|
||||
// 1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids)
|
||||
&& ids.stream().anyMatch(whereFilterByTimeFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. AGG queryType
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.AGGREGATE;
|
||||
@@ -76,20 +51,4 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,
|
||||
SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet =
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
return selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||
@@ -124,18 +123,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
SchemaElement element = schemaMatch.getElement();
|
||||
element.setOrder(1 - schemaMatch.getSimilarity());
|
||||
switch (element.getType()) {
|
||||
case ID:
|
||||
SchemaElement entityElement =
|
||||
semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
|
||||
if (entityElement != null) {
|
||||
if (id2Values.containsKey(element.getId())) {
|
||||
id2Values.get(element.getId()).add(schemaMatch);
|
||||
} else {
|
||||
id2Values.put(element.getId(),
|
||||
new ArrayList<>(Arrays.asList(schemaMatch)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
case VALUE:
|
||||
SchemaElement dimElement =
|
||||
semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
|
||||
@@ -154,13 +141,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
case METRIC:
|
||||
parseInfo.getMetrics().add(element);
|
||||
break;
|
||||
case ENTITY:
|
||||
parseInfo.setEntity(element);
|
||||
break;
|
||||
default:
|
||||
}
|
||||
}
|
||||
addToFilters(id2Values, parseInfo, semanticSchema, SchemaElementType.ENTITY);
|
||||
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
|
||||
}
|
||||
|
||||
@@ -182,8 +165,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
dimensionFilter.setName(dimension.getName());
|
||||
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
dimensionFilter.setElementID(schemaMatch.getElement().getId());
|
||||
parseInfo.setEntity(
|
||||
semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
|
||||
parseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
} else {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
@@ -216,11 +197,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
return convertQueryMultiStruct();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
public static List<RuleSemanticQuery> resolve(Long dataSetId,
|
||||
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
|
||||
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
|
||||
@@ -228,7 +204,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
List<SchemaElementMatch> matches =
|
||||
semanticQuery.match(candidateElementMatches, chatQueryContext);
|
||||
|
||||
if (matches.size() > 0) {
|
||||
if (!matches.isEmpty()) {
|
||||
RuleSemanticQuery query =
|
||||
QueryManager.createRuleQuery(semanticQuery.getQueryMode());
|
||||
query.getParseInfo().getElementMatches().addAll(matches);
|
||||
@@ -238,10 +214,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
return matchedQueries;
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
protected QueryMultiStructReq convertQueryMultiStruct() {
|
||||
return QueryReqBuilder.buildMultiStructReq(parseInfo);
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class DetailFilterQuery extends DetailListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DETAIL_LIST_FILTER";
|
||||
|
||||
public DetailFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Component
|
||||
public class DetailIdQuery extends DetailListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DETAIL_ID";
|
||||
|
||||
public DetailIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
import com.tencent.supersonic.headless.api.pojo.DetailTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public abstract class DetailListQuery extends DetailSemanticQuery {
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
|
||||
}
|
||||
|
||||
private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo parseInfo) {
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
if (Objects.isNull(dataSetId) || dataSetId <= 0L) {
|
||||
return;
|
||||
}
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
|
||||
Set<SchemaElement> dimensions = new LinkedHashSet<>();
|
||||
Set<SchemaElement> metrics = new LinkedHashSet<>();
|
||||
Set<Order> orders = new LinkedHashSet<>();
|
||||
DetailTypeDefaultConfig detailTypeDefaultConfig =
|
||||
dataSetSchema.getTagTypeDefaultConfig();
|
||||
if (detailTypeDefaultConfig != null
|
||||
&& detailTypeDefaultConfig.getDefaultDisplayInfo() != null) {
|
||||
if (CollectionUtils.isNotEmpty(
|
||||
detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
|
||||
metrics = detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
|
||||
.stream().map(id -> {
|
||||
SchemaElement metric =
|
||||
dataSetSchema.getElement(SchemaElementType.METRIC, id);
|
||||
if (metric != null) {
|
||||
orders.add(
|
||||
new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
}
|
||||
return metric;
|
||||
}).filter(Objects::nonNull).collect(Collectors.toSet());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(
|
||||
detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
|
||||
dimensions = detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds()
|
||||
.stream()
|
||||
.map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id))
|
||||
.filter(Objects::nonNull).collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
parseInfo.setDimensions(dimensions);
|
||||
parseInfo.setMetrics(metrics);
|
||||
parseInfo.setOrders(orders);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,7 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId());
|
||||
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getDetailTypeTimeDefaultConfig();
|
||||
|
||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
|
||||
&& timeDefaultConfig.getUnit() != -1) {
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.metric;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricIdQuery extends MetricSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_ID";
|
||||
|
||||
public MetricIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1).addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryReq buildSemanticQueryReq() {
|
||||
if (!isMultiStructQuery()) {
|
||||
return super.buildSemanticQueryReq();
|
||||
}
|
||||
return super.multiStructExecute();
|
||||
}
|
||||
|
||||
protected boolean isMultiStructQuery() {
|
||||
Set<String> filterBizName = new HashSet<>();
|
||||
parseInfo.getDimensionFilters().stream().filter(filter -> filter.getElementID() != null)
|
||||
.forEach(filter -> filterBizName.add(filter.getBizName()));
|
||||
return FilterType.UNION.equals(parseInfo.getFilterType()) && filterBizName.size() > 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
QueryStructReq queryStructReq = super.convertQueryStruct();
|
||||
addDimension(queryStructReq, true);
|
||||
return queryStructReq;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected QueryMultiStructReq convertQueryMultiStruct() {
|
||||
QueryMultiStructReq queryMultiStructReq = super.convertQueryMultiStruct();
|
||||
for (QueryStructReq queryStructReq : queryMultiStructReq.getQueryStructReqs()) {
|
||||
addDimension(queryStructReq, false);
|
||||
}
|
||||
return queryMultiStructReq;
|
||||
}
|
||||
|
||||
private void addDimension(QueryStructReq queryStructReq, boolean onlyOperateInFilter) {
|
||||
if (!queryStructReq.getDimensionFilters().isEmpty()) {
|
||||
List<String> dimensions = queryStructReq.getGroups();
|
||||
log.info("addDimension before [{}]", queryStructReq.getGroups());
|
||||
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
|
||||
if (onlyOperateInFilter) {
|
||||
filters = filters.stream()
|
||||
.filter(filter -> filter.getOperator().equals(FilterOperatorEnum.IN))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
filters.forEach(d -> {
|
||||
if (!dimensions.contains(d.getBizName())) {
|
||||
dimensions.add(d.getBizName());
|
||||
}
|
||||
});
|
||||
queryStructReq.setGroups(dimensions);
|
||||
log.info("addDimension after [{}]", queryStructReq.getGroups());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
import com.tencent.supersonic.headless.api.pojo.DetailTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.core.env.Environment;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class SelectCorrectorTest {
|
||||
|
||||
Long dataSetId = 2L;
|
||||
|
||||
@Test
|
||||
void testDoCorrect() {
|
||||
MockedStatic<ContextUtils> mocked = Mockito.mockStatic(ContextUtils.class);
|
||||
Environment mockEnvironment = Mockito.mock(Environment.class);
|
||||
mocked.when(() -> ContextUtils.getBean(Environment.class)).thenReturn(mockEnvironment);
|
||||
when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn("");
|
||||
|
||||
BaseSemanticCorrector corrector = new SelectCorrector();
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
DetailTypeDefaultConfig detailTypeDefaultConfig = new DetailTypeDefaultConfig();
|
||||
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
|
||||
List<Long> dimensionIds = new ArrayList<>();
|
||||
dimensionIds.add(1L);
|
||||
dimensionIds.add(2L);
|
||||
dimensionIds.add(3L);
|
||||
defaultDisplayInfo.setDimensionIds(dimensionIds);
|
||||
|
||||
List<Long> metricIds = new ArrayList<>();
|
||||
metricIds.add(4L);
|
||||
defaultDisplayInfo.setMetricIds(metricIds);
|
||||
|
||||
detailTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
|
||||
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
|
||||
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSetId(dataSetId);
|
||||
element1.setId(1L);
|
||||
element1.setName("艺人名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setDataSetId(dataSetId);
|
||||
element2.setId(2L);
|
||||
element2.setName("性别");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setDataSetId(dataSetId);
|
||||
element3.setId(3L);
|
||||
element3.setName("国籍");
|
||||
dimensions.add(element3);
|
||||
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSetId(dataSetId);
|
||||
metric1.setId(4L);
|
||||
metric1.setName("粉丝数");
|
||||
metrics.add(metric1);
|
||||
|
||||
dataSetSchema.setMetrics(metrics);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
chatQueryContext.setSemanticSchema(semanticSchema);
|
||||
return chatQueryContext;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user