mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][chat]Refactor code logic in rule-based parsing.
This commit is contained in:
@@ -24,6 +24,10 @@ public class DataSetSchema implements Serializable {
|
||||
private Set<SchemaElement> terms = new HashSet<>();
|
||||
private QueryConfig queryConfig;
|
||||
|
||||
public Long getDataSetId() {
|
||||
return dataSet.getDataSetId();
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
|
||||
@@ -119,22 +119,26 @@ public class SemanticSchema implements Serializable {
|
||||
return getElementsById(dataSetId, dataSets).orElse(null);
|
||||
}
|
||||
|
||||
public QueryConfig getQueryConfig(Long dataSetId) {
|
||||
DataSetSchema first = dataSetSchemaList.stream().filter(
|
||||
dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSet().getDataSetId()))
|
||||
.findFirst().orElse(null);
|
||||
if (Objects.nonNull(first)) {
|
||||
return first.getQueryConfig();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDataSets() {
|
||||
List<SchemaElement> dataSets = new ArrayList<>();
|
||||
dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet()));
|
||||
return dataSets;
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||
return dataSetSchemaList.stream()
|
||||
.filter(dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSetId())).findFirst()
|
||||
.orElse(null);
|
||||
}
|
||||
|
||||
public QueryConfig getQueryConfig(Long dataSetId) {
|
||||
DataSetSchema dataSetSchema = getDataSetSchema(dataSetId);
|
||||
if (Objects.nonNull(dataSetSchema)) {
|
||||
return dataSetSchema.getQueryConfig();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
|
||||
return new HashMap<>();
|
||||
|
||||
@@ -188,36 +188,31 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (!CollectionUtils.isEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
selectItems.add(buildAggregatorSelectItem(aggregator, queryStructReq));
|
||||
selectItems.add(buildAggregatorSelectItem(aggregator));
|
||||
}
|
||||
}
|
||||
|
||||
return selectItems;
|
||||
}
|
||||
|
||||
private SelectItem buildAggregatorSelectItem(Aggregator aggregator,
|
||||
QueryStructReq queryStructReq) {
|
||||
private SelectItem buildAggregatorSelectItem(Aggregator aggregator) {
|
||||
String columnName = aggregator.getColumn();
|
||||
if (queryStructReq.getQueryType().isNativeAggQuery()) {
|
||||
return new SelectItem(new Column(columnName));
|
||||
} else {
|
||||
Function function = new Function();
|
||||
AggOperatorEnum func = aggregator.getFunc();
|
||||
if (AggOperatorEnum.UNKNOWN.equals(func)) {
|
||||
func = AggOperatorEnum.SUM;
|
||||
}
|
||||
function.setName(func.getOperator());
|
||||
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
|
||||
function.setName("count");
|
||||
function.setDistinct(true);
|
||||
}
|
||||
function.setParameters(new ExpressionList(new Column(columnName)));
|
||||
SelectItem selectExpressionItem = new SelectItem(function);
|
||||
String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias()
|
||||
: columnName;
|
||||
selectExpressionItem.setAlias(new Alias(alias));
|
||||
return selectExpressionItem;
|
||||
Function function = new Function();
|
||||
AggOperatorEnum func = aggregator.getFunc();
|
||||
if (AggOperatorEnum.UNKNOWN.equals(func)) {
|
||||
func = AggOperatorEnum.SUM;
|
||||
}
|
||||
function.setName(func.getOperator());
|
||||
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
|
||||
function.setName("count");
|
||||
function.setDistinct(true);
|
||||
}
|
||||
function.setParameters(new ExpressionList(new Column(columnName)));
|
||||
SelectItem selectExpressionItem = new SelectItem(function);
|
||||
String alias =
|
||||
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
|
||||
selectExpressionItem.setAlias(new Alias(alias));
|
||||
return selectExpressionItem;
|
||||
}
|
||||
|
||||
private List<OrderByElement> buildOrderByElements(QueryStructReq queryStructReq) {
|
||||
@@ -241,7 +236,7 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
|
||||
private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) {
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getQueryType().isNativeAggQuery()) {
|
||||
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty()) {
|
||||
GroupByElement groupByElement = new GroupByElement();
|
||||
for (String group : groups) {
|
||||
groupByElement.addGroupByExpression(new Column(group));
|
||||
|
||||
@@ -41,6 +41,10 @@ public class ChatQueryContext implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||
return semanticSchema.getDataSetSchema(dataSetId);
|
||||
}
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
candidateQueries = candidateQueries.stream()
|
||||
.sorted(Comparator.comparing(
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
||||
@Slf4j
|
||||
@@ -20,34 +12,17 @@ public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
chatQueryContext.getCandidateQueries().forEach(query -> {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
|
||||
QueryType queryType = QueryType.DETAIL;
|
||||
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
User user = chatQueryContext.getRequest().getUser();
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
|
||||
queryType = QueryType.AGGREGATE;
|
||||
}
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.AGGREGATE;
|
||||
}
|
||||
|
||||
return QueryType.DETAIL;
|
||||
parseInfo.setQueryType(queryType);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ public class LLMResponseService {
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
|
||||
@@ -34,15 +34,13 @@ public class RuleSqlParser implements SemanticParser {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries =
|
||||
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
|
||||
chatQueryContext.getCandidateQueries().clear();
|
||||
candidateQueries.addAll(queries);
|
||||
}
|
||||
chatQueryContext.setCandidateQueries(candidateQueries);
|
||||
|
||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||
|
||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,87 +1,24 @@
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@Data
|
||||
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
|
||||
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo getParseInfo() {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryReq buildSemanticQueryReq() {
|
||||
return QueryReqBuilder.buildS2SQLReq(parseInfo.getSqlInfo(), parseInfo.getDataSetId());
|
||||
}
|
||||
|
||||
protected void initS2SqlByStruct(DataSetSchema dataSetSchema) {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(DataSetSchema dataSetSchema,
|
||||
QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
if (CollectionUtils.isNotEmpty(orders)) {
|
||||
for (Order order : orders) {
|
||||
order.setColumn(bizNameToName.get(order.getColumn()));
|
||||
}
|
||||
}
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (CollectionUtils.isNotEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
|
||||
}
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters
|
||||
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
@@ -13,7 +12,7 @@ public interface SemanticQuery {
|
||||
|
||||
SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;
|
||||
|
||||
void initS2Sql(DataSetSchema dataSetSchema, User user);
|
||||
void buildS2Sql(DataSetSchema dataSetSchema);
|
||||
|
||||
SemanticParseInfo getParseInfo();
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
@@ -24,7 +23,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
@@ -10,6 +13,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||
@@ -17,6 +22,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -26,7 +32,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM;
|
||||
@@ -50,14 +55,24 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
|
||||
initS2SqlByStruct(dataSetSchema);
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
parseInfo.setQueryMode(getQueryMode());
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
protected void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
parseInfo.setQueryMode(getQueryMode());
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
|
||||
fillSchemaElement(parseInfo, semanticSchema);
|
||||
fillScore(parseInfo);
|
||||
fillDateConfByInherited(parseInfo, chatQueryContext);
|
||||
@@ -110,12 +125,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
|
||||
Set<Long> dataSetIds =
|
||||
parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
|
||||
Long dataSetId = dataSetIds.iterator().next();
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
|
||||
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
|
||||
@@ -200,14 +210,15 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
public static List<RuleSemanticQuery> resolve(Long dataSetId,
|
||||
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
|
||||
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
|
||||
|
||||
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
|
||||
List<SchemaElementMatch> matches =
|
||||
semanticQuery.match(candidateElementMatches, chatQueryContext);
|
||||
|
||||
if (!matches.isEmpty()) {
|
||||
RuleSemanticQuery query =
|
||||
QueryManager.createRuleQuery(semanticQuery.getQueryMode());
|
||||
query.getParseInfo().getElementMatches().addAll(matches);
|
||||
query.fillParseInfo(chatQueryContext, dataSetId);
|
||||
matchedQueries.add(query);
|
||||
}
|
||||
}
|
||||
@@ -217,4 +228,39 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
protected QueryMultiStructReq convertQueryMultiStruct() {
|
||||
return QueryReqBuilder.buildMultiStructReq(parseInfo);
|
||||
}
|
||||
|
||||
|
||||
protected void convertBizNameToName(DataSetSchema dataSetSchema,
|
||||
QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
if (CollectionUtils.isNotEmpty(orders)) {
|
||||
for (Order order : orders) {
|
||||
order.setColumn(bizNameToName.get(order.getColumn()));
|
||||
}
|
||||
}
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (CollectionUtils.isNotEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
|
||||
}
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters
|
||||
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
@@ -29,12 +28,10 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||
|
||||
parseInfo.setQueryType(QueryType.DETAIL);
|
||||
parseInfo.setLimit(parseInfo.getDetailLimit());
|
||||
|
||||
if (!needFillDateConf(chatQueryContext)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -31,16 +31,14 @@ public class DetailValueQuery extends DetailSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
parseInfo.getDimensions().addAll(semanticSchema.getDimensions());
|
||||
parseInfo.getDimensions().forEach(d -> {
|
||||
parseInfo.getElementMatches()
|
||||
.add(SchemaElementMatch.builder().element(d).word(d.getName()).similarity(0)
|
||||
.isInherited(false).detectWord(d.getName()).build());
|
||||
});
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getDataSetSchema(dataSetId);
|
||||
parseInfo.getDimensions().addAll(dataSetSchema.getDimensions());
|
||||
parseInfo.getDimensions().forEach(
|
||||
d -> parseInfo.getElementMatches().add(SchemaElementMatch.builder().element(d)
|
||||
.word(d.getName()).similarity(0).detectWord(d.getName()).build()));
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -32,8 +32,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||
|
||||
parseInfo.setLimit(parseInfo.getMetricLimit());
|
||||
fillDateInfo(chatQueryContext);
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||
|
||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||
parseInfo.setAggType(AggregateTypeEnum.SUM);
|
||||
|
||||
@@ -17,14 +17,12 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
@@ -51,12 +49,6 @@ public class QueryReqBuilder {
|
||||
chatFilter.getOperator(), chatFilter.getValue()))
|
||||
.collect(Collectors.toList());
|
||||
queryStructReq.setMetricFilters(metricFilters);
|
||||
|
||||
addDateDimension(parseInfo);
|
||||
|
||||
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
|
||||
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
|
||||
}
|
||||
queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
|
||||
.collect(Collectors.toList()));
|
||||
queryStructReq.setLimit(parseInfo.getLimit());
|
||||
@@ -155,51 +147,6 @@ public class QueryReqBuilder {
|
||||
return aggregateType.name();
|
||||
}
|
||||
|
||||
private static void addDateDimension(SemanticParseInfo parseInfo) {
|
||||
if (parseInfo == null || parseInfo.getDateInfo() == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (shouldSkipAddingDateDimension(parseInfo)) {
|
||||
return;
|
||||
}
|
||||
|
||||
String dateField = getDateField(parseInfo.getDateInfo());
|
||||
if (isDateFieldAlreadyPresent(parseInfo, dateField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
SchemaElement dimension = new SchemaElement();
|
||||
dimension.setBizName(dateField);
|
||||
|
||||
if (QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
|
||||
addDimension(parseInfo, dimension);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean shouldSkipAddingDateDimension(SemanticParseInfo parseInfo) {
|
||||
return parseInfo.getAggType() != null
|
||||
&& (parseInfo.getAggType().equals(AggregateTypeEnum.MAX)
|
||||
|| parseInfo.getAggType().equals(AggregateTypeEnum.MIN))
|
||||
&& !CollectionUtils.isEmpty(parseInfo.getDimensions());
|
||||
}
|
||||
|
||||
private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo,
|
||||
String dateField) {
|
||||
return parseInfo.getDimensions().stream()
|
||||
.anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField));
|
||||
}
|
||||
|
||||
private static void addDimension(SemanticParseInfo parseInfo, SchemaElement dimension) {
|
||||
List<String> timeDimensions = Arrays.asList(TimeDimensionEnum.DAY.getName(),
|
||||
TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.MONTH.getName());
|
||||
Set<SchemaElement> dimensions = parseInfo.getDimensions().stream()
|
||||
.filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase()))
|
||||
.collect(Collectors.toSet());
|
||||
dimensions.add(dimension);
|
||||
parseInfo.setDimensions(dimensions);
|
||||
}
|
||||
|
||||
public static Set<Order> getOrder(Set<Order> existingOrders, AggregateTypeEnum aggregator,
|
||||
SchemaElement metric) {
|
||||
if (existingOrders != null && !existingOrders.isEmpty()) {
|
||||
|
||||
Reference in New Issue
Block a user