diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 113cf1140..9c95d1fc3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -231,7 +231,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { log.info("rule begin replace metrics and revise filters!"); validFilter(semanticQuery.getParseInfo().getDimensionFilters()); validFilter(semanticQuery.getParseInfo().getMetricFilters()); - semanticQuery.initS2Sql(dataSetSchema, user); + semanticQuery.buildS2Sql(dataSetSchema); } private QueryResult executeQuery(SemanticQuery semanticQuery, User user) throws Exception { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index 10aee9bdc..7d0173f77 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -24,6 +24,10 @@ public class DataSetSchema implements Serializable { private Set terms = new HashSet<>(); private QueryConfig queryConfig; + public Long getDataSetId() { + return dataSet.getDataSetId(); + } + public SchemaElement getElement(SchemaElementType elementType, long elementID) { Optional element = Optional.empty(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java index e19153c1c..e20212e1e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java @@ -119,22 +119,26 @@ public class SemanticSchema implements Serializable { return getElementsById(dataSetId, dataSets).orElse(null); } - public QueryConfig getQueryConfig(Long dataSetId) { - DataSetSchema first = dataSetSchemaList.stream().filter( - dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSet().getDataSetId())) - .findFirst().orElse(null); - if (Objects.nonNull(first)) { - return first.getQueryConfig(); - } - return null; - } - public List getDataSets() { List dataSets = new ArrayList<>(); dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet())); return dataSets; } + public DataSetSchema getDataSetSchema(Long dataSetId) { + return dataSetSchemaList.stream() + .filter(dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSetId())).findFirst() + .orElse(null); + } + + public QueryConfig getQueryConfig(Long dataSetId) { + DataSetSchema dataSetSchema = getDataSetSchema(dataSetId); + if (Objects.nonNull(dataSetSchema)) { + return dataSetSchema.getQueryConfig(); + } + return null; + } + public Map getDataSetSchemaMap() { if (CollectionUtils.isEmpty(dataSetSchemaList)) { return new HashMap<>(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java index 9ddc4e0d7..752f193c6 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java @@ -188,36 +188,31 @@ public class QueryStructReq extends SemanticQueryReq { List aggregators = queryStructReq.getAggregators(); if (!CollectionUtils.isEmpty(aggregators)) { for (Aggregator aggregator : aggregators) { - selectItems.add(buildAggregatorSelectItem(aggregator, queryStructReq)); + selectItems.add(buildAggregatorSelectItem(aggregator)); } } return selectItems; } - private SelectItem buildAggregatorSelectItem(Aggregator aggregator, - QueryStructReq queryStructReq) { + private SelectItem buildAggregatorSelectItem(Aggregator aggregator) { String columnName = aggregator.getColumn(); - if (queryStructReq.getQueryType().isNativeAggQuery()) { - return new SelectItem(new Column(columnName)); - } else { - Function function = new Function(); - AggOperatorEnum func = aggregator.getFunc(); - if (AggOperatorEnum.UNKNOWN.equals(func)) { - func = AggOperatorEnum.SUM; - } - function.setName(func.getOperator()); - if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) { - function.setName("count"); - function.setDistinct(true); - } - function.setParameters(new ExpressionList(new Column(columnName))); - SelectItem selectExpressionItem = new SelectItem(function); - String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() - : columnName; - selectExpressionItem.setAlias(new Alias(alias)); - return selectExpressionItem; + Function function = new Function(); + AggOperatorEnum func = aggregator.getFunc(); + if (AggOperatorEnum.UNKNOWN.equals(func)) { + func = AggOperatorEnum.SUM; } + function.setName(func.getOperator()); + if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) { + function.setName("count"); + function.setDistinct(true); + } + function.setParameters(new ExpressionList(new Column(columnName))); + SelectItem selectExpressionItem = new SelectItem(function); + String alias = + StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName; + selectExpressionItem.setAlias(new Alias(alias)); + return selectExpressionItem; } private List buildOrderByElements(QueryStructReq queryStructReq) { @@ -241,7 +236,7 @@ public class QueryStructReq extends SemanticQueryReq { private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) { List groups = queryStructReq.getGroups(); - if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getQueryType().isNativeAggQuery()) { + if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty()) { GroupByElement groupByElement = new GroupByElement(); for (String group : groups) { groupByElement.addGroupByExpression(new Column(group)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index a55e0c409..b1167c78b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -41,6 +41,10 @@ public class ChatQueryContext implements Serializable { } } + public DataSetSchema getDataSetSchema(Long dataSetId) { + return semanticSchema.getDataSetSchema(dataSetId); + } + public List getCandidateQueries() { candidateQueries = candidateQueries.stream() .sorted(Comparator.comparing( diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index 4ea5d4e7b..c11182988 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -1,18 +1,10 @@ package com.tencent.supersonic.headless.chat.parser; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; -import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.QueryType; -import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; -import com.tencent.supersonic.headless.chat.query.SemanticQuery; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; - -import java.util.List; -import java.util.Objects; /** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */ @Slf4j @@ -20,34 +12,17 @@ public class QueryTypeParser implements SemanticParser { @Override public void parse(ChatQueryContext chatQueryContext) { + chatQueryContext.getCandidateQueries().forEach(query -> { + SemanticParseInfo parseInfo = query.getParseInfo(); + String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL(); + QueryType queryType = QueryType.DETAIL; - List candidateQueries = chatQueryContext.getCandidateQueries(); - User user = chatQueryContext.getRequest().getUser(); + if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) { + queryType = QueryType.AGGREGATE; + } - for (SemanticQuery semanticQuery : candidateQueries) { - // 1.init S2SQL - Long dataSetId = semanticQuery.getParseInfo().getDataSetId(); - DataSetSchema dataSetSchema = - chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); - semanticQuery.initS2Sql(dataSetSchema, user); - // 2.set queryType - QueryType queryType = getQueryType(semanticQuery); - semanticQuery.getParseInfo().setQueryType(queryType); - } - } - - private QueryType getQueryType(SemanticQuery semanticQuery) { - SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); - SqlInfo sqlInfo = parseInfo.getSqlInfo(); - if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) { - return QueryType.DETAIL; - } - - if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) { - return QueryType.AGGREGATE; - } - - return QueryType.DETAIL; + parseInfo.setQueryType(queryType); + }); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 9b0aeb554..59b1d8560 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -49,6 +49,7 @@ public class LLMResponseService { parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); + parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL); queryCtx.getCandidateQueries().add(semanticQuery); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index 34a93975c..bb3e65305 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -34,15 +34,13 @@ public class RuleSqlParser implements SemanticParser { List elementMatches = mapInfo.getMatchedElements(dataSetId); List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext); - for (RuleSemanticQuery query : queries) { - query.fillParseInfo(chatQueryContext); - chatQueryContext.getCandidateQueries().add(query); - } - candidateQueries.addAll(chatQueryContext.getCandidateQueries()); - chatQueryContext.getCandidateQueries().clear(); + candidateQueries.addAll(queries); } chatQueryContext.setCandidateQueries(candidateQueries); auxiliaryParsers.forEach(p -> p.parse(chatQueryContext)); + + candidateQueries.forEach(query -> query.buildS2Sql( + chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId()))); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java index a6ce039ed..714cff66a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java @@ -1,87 +1,24 @@ package com.tencent.supersonic.headless.chat.query; -import com.tencent.supersonic.common.pojo.Aggregator; -import com.tencent.supersonic.common.pojo.Filter; -import com.tencent.supersonic.common.pojo.Order; -import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; -import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; +import lombok.Data; import lombok.ToString; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; @Slf4j @ToString +@Data public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { protected SemanticParseInfo parseInfo = new SemanticParseInfo(); - @Override - public SemanticParseInfo getParseInfo() { - return parseInfo; - } - - @Override - public void setParseInfo(SemanticParseInfo parseInfo) { - this.parseInfo = parseInfo; - } - - protected QueryStructReq convertQueryStruct() { - return QueryReqBuilder.buildStructReq(parseInfo); - } - @Override public SemanticQueryReq buildSemanticQueryReq() { return QueryReqBuilder.buildS2SQLReq(parseInfo.getSqlInfo(), parseInfo.getDataSetId()); } - protected void initS2SqlByStruct(DataSetSchema dataSetSchema) { - QueryStructReq queryStructReq = convertQueryStruct(); - convertBizNameToName(dataSetSchema, queryStructReq); - QuerySqlReq querySQLReq = queryStructReq.convert(); - parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql()); - parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql()); - } - - protected void convertBizNameToName(DataSetSchema dataSetSchema, - QueryStructReq queryStructReq) { - Map bizNameToName = dataSetSchema.getBizNameToName(); - bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap()); - - List orders = queryStructReq.getOrders(); - if (CollectionUtils.isNotEmpty(orders)) { - for (Order order : orders) { - order.setColumn(bizNameToName.get(order.getColumn())); - } - } - List aggregators = queryStructReq.getAggregators(); - if (CollectionUtils.isNotEmpty(aggregators)) { - for (Aggregator aggregator : aggregators) { - aggregator.setColumn(bizNameToName.get(aggregator.getColumn())); - } - } - List groups = queryStructReq.getGroups(); - if (CollectionUtils.isNotEmpty(groups)) { - groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList()); - queryStructReq.setGroups(groups); - } - List dimensionFilters = queryStructReq.getDimensionFilters(); - if (CollectionUtils.isNotEmpty(dimensionFilters)) { - dimensionFilters - .forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); - } - List metricFilters = queryStructReq.getMetricFilters(); - if (CollectionUtils.isNotEmpty(dimensionFilters)) { - metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); - } - } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java index e1025bdb7..1a8d0448d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/SemanticQuery.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.query; -import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; @@ -13,7 +12,7 @@ public interface SemanticQuery { SemanticQueryReq buildSemanticQueryReq() throws SqlParseException; - void initS2Sql(DataSetSchema dataSetSchema, User user); + void buildS2Sql(DataSetSchema dataSetSchema); SemanticParseInfo getParseInfo(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java index 8ba599fe9..775bb4202 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; -import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.chat.query.QueryManager; @@ -24,7 +23,7 @@ public class LLMSqlQuery extends LLMSemanticQuery { } @Override - public void initS2Sql(DataSetSchema dataSetSchema, User user) { + public void buildS2Sql(DataSetSchema dataSetSchema) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL()); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index b4ddb8db8..a2c0895ca 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -1,7 +1,10 @@ package com.tencent.supersonic.headless.chat.query.rule; -import com.tencent.supersonic.common.pojo.User; +import com.tencent.supersonic.common.pojo.Aggregator; +import com.tencent.supersonic.common.pojo.Filter; +import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; @@ -10,6 +13,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; +import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; +import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery; @@ -17,6 +22,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import lombok.ToString; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; @@ -26,7 +32,6 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM; @@ -50,14 +55,24 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } @Override - public void initS2Sql(DataSetSchema dataSetSchema, User user) { - initS2SqlByStruct(dataSetSchema); + public void buildS2Sql(DataSetSchema dataSetSchema) { + QueryStructReq queryStructReq = convertQueryStruct(); + convertBizNameToName(dataSetSchema, queryStructReq); + QuerySqlReq querySQLReq = queryStructReq.convert(); + parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql()); + parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql()); } - public void fillParseInfo(ChatQueryContext chatQueryContext) { - parseInfo.setQueryMode(getQueryMode()); + protected QueryStructReq convertQueryStruct() { + return QueryReqBuilder.buildStructReq(parseInfo); + } + + protected void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); + parseInfo.setQueryMode(getQueryMode()); + parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId)); + parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId)); fillSchemaElement(parseInfo, semanticSchema); fillScore(parseInfo); fillDateConfByInherited(parseInfo, chatQueryContext); @@ -110,12 +125,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) { - Set dataSetIds = - parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement) - .map(SchemaElement::getDataSetId).collect(Collectors.toSet()); - Long dataSetId = dataSetIds.iterator().next(); - parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId)); - parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId)); + Map> dim2Values = new HashMap<>(); for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) { @@ -200,14 +210,15 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { public static List resolve(Long dataSetId, List candidateElementMatches, ChatQueryContext chatQueryContext) { List matchedQueries = new ArrayList<>(); + for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) { List matches = semanticQuery.match(candidateElementMatches, chatQueryContext); - if (!matches.isEmpty()) { RuleSemanticQuery query = QueryManager.createRuleQuery(semanticQuery.getQueryMode()); query.getParseInfo().getElementMatches().addAll(matches); + query.fillParseInfo(chatQueryContext, dataSetId); matchedQueries.add(query); } } @@ -217,4 +228,39 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { protected QueryMultiStructReq convertQueryMultiStruct() { return QueryReqBuilder.buildMultiStructReq(parseInfo); } + + + protected void convertBizNameToName(DataSetSchema dataSetSchema, + QueryStructReq queryStructReq) { + Map bizNameToName = dataSetSchema.getBizNameToName(); + bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap()); + + List orders = queryStructReq.getOrders(); + if (CollectionUtils.isNotEmpty(orders)) { + for (Order order : orders) { + order.setColumn(bizNameToName.get(order.getColumn())); + } + } + List aggregators = queryStructReq.getAggregators(); + if (CollectionUtils.isNotEmpty(aggregators)) { + for (Aggregator aggregator : aggregators) { + aggregator.setColumn(bizNameToName.get(aggregator.getColumn())); + } + } + List groups = queryStructReq.getGroups(); + if (CollectionUtils.isNotEmpty(groups)) { + groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList()); + queryStructReq.setGroups(groups); + } + List dimensionFilters = queryStructReq.getDimensionFilters(); + if (CollectionUtils.isNotEmpty(dimensionFilters)) { + dimensionFilters + .forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); + } + List metricFilters = queryStructReq.getMetricFilters(); + if (CollectionUtils.isNotEmpty(dimensionFilters)) { + metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); + } + } + } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java index 0bb9acb61..4ad32b274 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.chat.query.rule.detail; import com.tencent.supersonic.common.pojo.DateConf; -import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; @@ -29,12 +28,10 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery { } @Override - public void fillParseInfo(ChatQueryContext chatQueryContext) { - super.fillParseInfo(chatQueryContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) { + super.fillParseInfo(chatQueryContext, dataSetId); - parseInfo.setQueryType(QueryType.DETAIL); parseInfo.setLimit(parseInfo.getDetailLimit()); - if (!needFillDateConf(chatQueryContext)) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailValueQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailValueQuery.java index 0bb27d61d..6f4cd5478 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailValueQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailValueQuery.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless.chat.query.rule.detail; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.springframework.stereotype.Component; @@ -31,16 +31,14 @@ public class DetailValueQuery extends DetailSemanticQuery { } @Override - public void fillParseInfo(ChatQueryContext chatQueryContext) { - super.fillParseInfo(chatQueryContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) { + super.fillParseInfo(chatQueryContext, dataSetId); - SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); - parseInfo.getDimensions().addAll(semanticSchema.getDimensions()); - parseInfo.getDimensions().forEach(d -> { - parseInfo.getElementMatches() - .add(SchemaElementMatch.builder().element(d).word(d.getName()).similarity(0) - .isInherited(false).detectWord(d.getName()).build()); - }); + DataSetSchema dataSetSchema = chatQueryContext.getDataSetSchema(dataSetId); + parseInfo.getDimensions().addAll(dataSetSchema.getDimensions()); + parseInfo.getDimensions().forEach( + d -> parseInfo.getElementMatches().add(SchemaElementMatch.builder().element(d) + .word(d.getName()).similarity(0).detectWord(d.getName()).build())); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java index cbd482cf1..be4767cab 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java @@ -32,8 +32,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery { } @Override - public void fillParseInfo(ChatQueryContext chatQueryContext) { - super.fillParseInfo(chatQueryContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) { + super.fillParseInfo(chatQueryContext, dataSetId); + parseInfo.setLimit(parseInfo.getMetricLimit()); fillDateInfo(chatQueryContext); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java index 10f2daffc..6349ed1e7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java @@ -48,8 +48,8 @@ public class MetricTopNQuery extends MetricSemanticQuery { } @Override - public void fillParseInfo(ChatQueryContext chatQueryContext) { - super.fillParseInfo(chatQueryContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) { + super.fillParseInfo(chatQueryContext, dataSetId); parseInfo.setScore(parseInfo.getScore() + 2.0); parseInfo.setAggType(AggregateTypeEnum.SUM); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java index cb1dd57b3..a1b6f45ea 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java @@ -17,14 +17,12 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; -import com.tencent.supersonic.headless.chat.query.QueryManager; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashSet; @@ -51,12 +49,6 @@ public class QueryReqBuilder { chatFilter.getOperator(), chatFilter.getValue())) .collect(Collectors.toList()); queryStructReq.setMetricFilters(metricFilters); - - addDateDimension(parseInfo); - - if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) { - parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime()); - } queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName) .collect(Collectors.toList())); queryStructReq.setLimit(parseInfo.getLimit()); @@ -155,51 +147,6 @@ public class QueryReqBuilder { return aggregateType.name(); } - private static void addDateDimension(SemanticParseInfo parseInfo) { - if (parseInfo == null || parseInfo.getDateInfo() == null) { - return; - } - - if (shouldSkipAddingDateDimension(parseInfo)) { - return; - } - - String dateField = getDateField(parseInfo.getDateInfo()); - if (isDateFieldAlreadyPresent(parseInfo, dateField)) { - return; - } - - SchemaElement dimension = new SchemaElement(); - dimension.setBizName(dateField); - - if (QueryManager.isMetricQuery(parseInfo.getQueryMode())) { - addDimension(parseInfo, dimension); - } - } - - private static boolean shouldSkipAddingDateDimension(SemanticParseInfo parseInfo) { - return parseInfo.getAggType() != null - && (parseInfo.getAggType().equals(AggregateTypeEnum.MAX) - || parseInfo.getAggType().equals(AggregateTypeEnum.MIN)) - && !CollectionUtils.isEmpty(parseInfo.getDimensions()); - } - - private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo, - String dateField) { - return parseInfo.getDimensions().stream() - .anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField)); - } - - private static void addDimension(SemanticParseInfo parseInfo, SchemaElement dimension) { - List timeDimensions = Arrays.asList(TimeDimensionEnum.DAY.getName(), - TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.MONTH.getName()); - Set dimensions = parseInfo.getDimensions().stream() - .filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase())) - .collect(Collectors.toSet()); - dimensions.add(dimension); - parseInfo.setDimensions(dimensions); - } - public static Set getOrder(Set existingOrders, AggregateTypeEnum aggregator, SchemaElement metric) { if (existingOrders != null && !existingOrders.isEmpty()) { diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 624ae8a1a..430570a9a 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -18,7 +18,13 @@ import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuer import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; -import com.tencent.supersonic.common.pojo.enums.*; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; +import com.tencent.supersonic.common.pojo.enums.AppModule; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; +import com.tencent.supersonic.common.pojo.enums.StatusEnum; +import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.DataSetDetail; @@ -40,7 +46,15 @@ import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; -import com.tencent.supersonic.headless.api.pojo.request.*; +import com.tencent.supersonic.headless.api.pojo.request.DataSetReq; +import com.tencent.supersonic.headless.api.pojo.request.DictItemReq; +import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq; +import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; +import com.tencent.supersonic.headless.api.pojo.request.DomainReq; +import com.tencent.supersonic.headless.api.pojo.request.MetricReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelReq; +import com.tencent.supersonic.headless.api.pojo.request.TagObjectReq; +import com.tencent.supersonic.headless.api.pojo.request.TermReq; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; @@ -48,12 +62,15 @@ import com.tencent.supersonic.headless.api.pojo.response.DomainResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.TagObjectResp; -import io.swagger.models.auth.In; import lombok.extern.slf4j.Slf4j; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; @Component @Slf4j @@ -131,7 +148,6 @@ public class S2VisitsDemo extends S2BaseDemo { submitText(chatId.intValue(), agentId, "超音数 访问次数"); submitText(chatId.intValue(), agentId, "按部门统计近7天访问次数"); submitText(chatId.intValue(), agentId, "alice 停留时长"); - submitText(chatId.intValue(), agentId, "访问次数最高的部门"); } private void submitText(int chatId, int agentId, String queryText) {