[improvement][chat]Refactor code logic in rule-based parsing.

This commit is contained in:
Jun Zhang
2024-11-09 15:00:08 +08:00
committed by jerryjzhang
parent d4a9d5a7e6
commit e0e167fd40
18 changed files with 154 additions and 233 deletions

View File

@@ -24,6 +24,10 @@ public class DataSetSchema implements Serializable {
private Set<SchemaElement> terms = new HashSet<>();
private QueryConfig queryConfig;
public Long getDataSetId() {
return dataSet.getDataSetId();
}
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();

View File

@@ -119,22 +119,26 @@ public class SemanticSchema implements Serializable {
return getElementsById(dataSetId, dataSets).orElse(null);
}
public QueryConfig getQueryConfig(Long dataSetId) {
DataSetSchema first = dataSetSchemaList.stream().filter(
dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSet().getDataSetId()))
.findFirst().orElse(null);
if (Objects.nonNull(first)) {
return first.getQueryConfig();
}
return null;
}
public List<SchemaElement> getDataSets() {
List<SchemaElement> dataSets = new ArrayList<>();
dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet()));
return dataSets;
}
public DataSetSchema getDataSetSchema(Long dataSetId) {
return dataSetSchemaList.stream()
.filter(dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSetId())).findFirst()
.orElse(null);
}
public QueryConfig getQueryConfig(Long dataSetId) {
DataSetSchema dataSetSchema = getDataSetSchema(dataSetId);
if (Objects.nonNull(dataSetSchema)) {
return dataSetSchema.getQueryConfig();
}
return null;
}
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
return new HashMap<>();

View File

@@ -188,36 +188,31 @@ public class QueryStructReq extends SemanticQueryReq {
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (!CollectionUtils.isEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
selectItems.add(buildAggregatorSelectItem(aggregator, queryStructReq));
selectItems.add(buildAggregatorSelectItem(aggregator));
}
}
return selectItems;
}
private SelectItem buildAggregatorSelectItem(Aggregator aggregator,
QueryStructReq queryStructReq) {
private SelectItem buildAggregatorSelectItem(Aggregator aggregator) {
String columnName = aggregator.getColumn();
if (queryStructReq.getQueryType().isNativeAggQuery()) {
return new SelectItem(new Column(columnName));
} else {
Function function = new Function();
AggOperatorEnum func = aggregator.getFunc();
if (AggOperatorEnum.UNKNOWN.equals(func)) {
func = AggOperatorEnum.SUM;
}
function.setName(func.getOperator());
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
function.setName("count");
function.setDistinct(true);
}
function.setParameters(new ExpressionList(new Column(columnName)));
SelectItem selectExpressionItem = new SelectItem(function);
String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias()
: columnName;
selectExpressionItem.setAlias(new Alias(alias));
return selectExpressionItem;
Function function = new Function();
AggOperatorEnum func = aggregator.getFunc();
if (AggOperatorEnum.UNKNOWN.equals(func)) {
func = AggOperatorEnum.SUM;
}
function.setName(func.getOperator());
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
function.setName("count");
function.setDistinct(true);
}
function.setParameters(new ExpressionList(new Column(columnName)));
SelectItem selectExpressionItem = new SelectItem(function);
String alias =
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
selectExpressionItem.setAlias(new Alias(alias));
return selectExpressionItem;
}
private List<OrderByElement> buildOrderByElements(QueryStructReq queryStructReq) {
@@ -241,7 +236,7 @@ public class QueryStructReq extends SemanticQueryReq {
private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) {
List<String> groups = queryStructReq.getGroups();
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getQueryType().isNativeAggQuery()) {
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty()) {
GroupByElement groupByElement = new GroupByElement();
for (String group : groups) {
groupByElement.addGroupByExpression(new Column(group));

View File

@@ -41,6 +41,10 @@ public class ChatQueryContext implements Serializable {
}
}
public DataSetSchema getDataSetSchema(Long dataSetId) {
return semanticSchema.getDataSetSchema(dataSetId);
}
public List<SemanticQuery> getCandidateQueries() {
candidateQueries = candidateQueries.stream()
.sorted(Comparator.comparing(

View File

@@ -1,18 +1,10 @@
package com.tencent.supersonic.headless.chat.parser;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
@Slf4j
@@ -20,34 +12,17 @@ public class QueryTypeParser implements SemanticParser {
@Override
public void parse(ChatQueryContext chatQueryContext) {
chatQueryContext.getCandidateQueries().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo();
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
QueryType queryType = QueryType.DETAIL;
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
User user = chatQueryContext.getRequest().getUser();
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
queryType = QueryType.AGGREGATE;
}
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
semanticQuery.initS2Sql(dataSetSchema, user);
// 2.set queryType
QueryType queryType = getQueryType(semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);
}
}
private QueryType getQueryType(SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
return QueryType.DETAIL;
}
if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) {
return QueryType.AGGREGATE;
}
return QueryType.DETAIL;
parseInfo.setQueryType(queryType);
});
}
}

View File

@@ -49,6 +49,7 @@ public class LLMResponseService {
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
}

View File

@@ -34,15 +34,13 @@ public class RuleSqlParser implements SemanticParser {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries =
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(chatQueryContext);
chatQueryContext.getCandidateQueries().add(query);
}
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
chatQueryContext.getCandidateQueries().clear();
candidateQueries.addAll(queries);
}
chatQueryContext.setCandidateQueries(candidateQueries);
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
candidateQueries.forEach(query -> query.buildS2Sql(
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
}
}

View File

@@ -1,87 +1,24 @@
package com.tencent.supersonic.headless.chat.query;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
@ToString
@Data
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
@Override
public SemanticParseInfo getParseInfo() {
return parseInfo;
}
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}
@Override
public SemanticQueryReq buildSemanticQueryReq() {
return QueryReqBuilder.buildS2SQLReq(parseInfo.getSqlInfo(), parseInfo.getDataSetId());
}
protected void initS2SqlByStruct(DataSetSchema dataSetSchema) {
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
}
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.query;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
@@ -13,7 +12,7 @@ public interface SemanticQuery {
SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;
void initS2Sql(DataSetSchema dataSetSchema, User user);
void buildS2Sql(DataSetSchema dataSetSchema);
SemanticParseInfo getParseInfo();

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -24,7 +23,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
}
@Override
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
public void buildS2Sql(DataSetSchema dataSetSchema) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
}

View File

@@ -1,7 +1,10 @@
package com.tencent.supersonic.headless.chat.query.rule;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
@@ -10,6 +13,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
@@ -17,6 +22,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
@@ -26,7 +32,6 @@ import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM;
@@ -50,14 +55,24 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}
@Override
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
initS2SqlByStruct(dataSetSchema);
public void buildS2Sql(DataSetSchema dataSetSchema) {
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
public void fillParseInfo(ChatQueryContext chatQueryContext) {
parseInfo.setQueryMode(getQueryMode());
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}
protected void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
parseInfo.setQueryMode(getQueryMode());
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
fillSchemaElement(parseInfo, semanticSchema);
fillScore(parseInfo);
fillDateConfByInherited(parseInfo, chatQueryContext);
@@ -110,12 +125,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds =
parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
@@ -200,14 +210,15 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public static List<RuleSemanticQuery> resolve(Long dataSetId,
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches =
semanticQuery.match(candidateElementMatches, chatQueryContext);
if (!matches.isEmpty()) {
RuleSemanticQuery query =
QueryManager.createRuleQuery(semanticQuery.getQueryMode());
query.getParseInfo().getElementMatches().addAll(matches);
query.fillParseInfo(chatQueryContext, dataSetId);
matchedQueries.add(query);
}
}
@@ -217,4 +228,39 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
protected QueryMultiStructReq convertQueryMultiStruct() {
return QueryReqBuilder.buildMultiStructReq(parseInfo);
}
protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
}
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
@@ -29,12 +28,10 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setQueryType(QueryType.DETAIL);
parseInfo.setLimit(parseInfo.getDetailLimit());
if (!needFillDateConf(chatQueryContext)) {
return;
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.springframework.stereotype.Component;
@@ -31,16 +31,14 @@ public class DetailValueQuery extends DetailSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext, dataSetId);
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
parseInfo.getDimensions().addAll(semanticSchema.getDimensions());
parseInfo.getDimensions().forEach(d -> {
parseInfo.getElementMatches()
.add(SchemaElementMatch.builder().element(d).word(d.getName()).similarity(0)
.isInherited(false).detectWord(d.getName()).build());
});
DataSetSchema dataSetSchema = chatQueryContext.getDataSetSchema(dataSetId);
parseInfo.getDimensions().addAll(dataSetSchema.getDimensions());
parseInfo.getDimensions().forEach(
d -> parseInfo.getElementMatches().add(SchemaElementMatch.builder().element(d)
.word(d.getName()).similarity(0).detectWord(d.getName()).build()));
}

View File

@@ -32,8 +32,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setLimit(parseInfo.getMetricLimit());
fillDateInfo(chatQueryContext);
}

View File

@@ -48,8 +48,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setScore(parseInfo.getScore() + 2.0);
parseInfo.setAggType(AggregateTypeEnum.SUM);

View File

@@ -17,14 +17,12 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
@@ -51,12 +49,6 @@ public class QueryReqBuilder {
chatFilter.getOperator(), chatFilter.getValue()))
.collect(Collectors.toList());
queryStructReq.setMetricFilters(metricFilters);
addDateDimension(parseInfo);
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
}
queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
.collect(Collectors.toList()));
queryStructReq.setLimit(parseInfo.getLimit());
@@ -155,51 +147,6 @@ public class QueryReqBuilder {
return aggregateType.name();
}
private static void addDateDimension(SemanticParseInfo parseInfo) {
if (parseInfo == null || parseInfo.getDateInfo() == null) {
return;
}
if (shouldSkipAddingDateDimension(parseInfo)) {
return;
}
String dateField = getDateField(parseInfo.getDateInfo());
if (isDateFieldAlreadyPresent(parseInfo, dateField)) {
return;
}
SchemaElement dimension = new SchemaElement();
dimension.setBizName(dateField);
if (QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
addDimension(parseInfo, dimension);
}
}
private static boolean shouldSkipAddingDateDimension(SemanticParseInfo parseInfo) {
return parseInfo.getAggType() != null
&& (parseInfo.getAggType().equals(AggregateTypeEnum.MAX)
|| parseInfo.getAggType().equals(AggregateTypeEnum.MIN))
&& !CollectionUtils.isEmpty(parseInfo.getDimensions());
}
private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo,
String dateField) {
return parseInfo.getDimensions().stream()
.anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField));
}
private static void addDimension(SemanticParseInfo parseInfo, SchemaElement dimension) {
List<String> timeDimensions = Arrays.asList(TimeDimensionEnum.DAY.getName(),
TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.MONTH.getName());
Set<SchemaElement> dimensions = parseInfo.getDimensions().stream()
.filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase()))
.collect(Collectors.toSet());
dimensions.add(dimension);
parseInfo.setDimensions(dimensions);
}
public static Set<Order> getOrder(Set<Order> existingOrders, AggregateTypeEnum aggregator,
SchemaElement metric) {
if (existingOrders != null && !existingOrders.isEmpty()) {