[improvement][chat]Refactor code logic in rule-based parsing.

This commit is contained in:
Jun Zhang
2024-11-09 15:00:08 +08:00
committed by jerryjzhang
parent d4a9d5a7e6
commit e0e167fd40
18 changed files with 154 additions and 233 deletions

View File

@@ -231,7 +231,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
log.info("rule begin replace metrics and revise filters!"); log.info("rule begin replace metrics and revise filters!");
validFilter(semanticQuery.getParseInfo().getDimensionFilters()); validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters()); validFilter(semanticQuery.getParseInfo().getMetricFilters());
semanticQuery.initS2Sql(dataSetSchema, user); semanticQuery.buildS2Sql(dataSetSchema);
} }
private QueryResult executeQuery(SemanticQuery semanticQuery, User user) throws Exception { private QueryResult executeQuery(SemanticQuery semanticQuery, User user) throws Exception {

View File

@@ -24,6 +24,10 @@ public class DataSetSchema implements Serializable {
private Set<SchemaElement> terms = new HashSet<>(); private Set<SchemaElement> terms = new HashSet<>();
private QueryConfig queryConfig; private QueryConfig queryConfig;
public Long getDataSetId() {
return dataSet.getDataSetId();
}
public SchemaElement getElement(SchemaElementType elementType, long elementID) { public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty(); Optional<SchemaElement> element = Optional.empty();

View File

@@ -119,22 +119,26 @@ public class SemanticSchema implements Serializable {
return getElementsById(dataSetId, dataSets).orElse(null); return getElementsById(dataSetId, dataSets).orElse(null);
} }
public QueryConfig getQueryConfig(Long dataSetId) {
DataSetSchema first = dataSetSchemaList.stream().filter(
dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSet().getDataSetId()))
.findFirst().orElse(null);
if (Objects.nonNull(first)) {
return first.getQueryConfig();
}
return null;
}
public List<SchemaElement> getDataSets() { public List<SchemaElement> getDataSets() {
List<SchemaElement> dataSets = new ArrayList<>(); List<SchemaElement> dataSets = new ArrayList<>();
dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet())); dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet()));
return dataSets; return dataSets;
} }
public DataSetSchema getDataSetSchema(Long dataSetId) {
return dataSetSchemaList.stream()
.filter(dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSetId())).findFirst()
.orElse(null);
}
public QueryConfig getQueryConfig(Long dataSetId) {
DataSetSchema dataSetSchema = getDataSetSchema(dataSetId);
if (Objects.nonNull(dataSetSchema)) {
return dataSetSchema.getQueryConfig();
}
return null;
}
public Map<Long, DataSetSchema> getDataSetSchemaMap() { public Map<Long, DataSetSchema> getDataSetSchemaMap() {
if (CollectionUtils.isEmpty(dataSetSchemaList)) { if (CollectionUtils.isEmpty(dataSetSchemaList)) {
return new HashMap<>(); return new HashMap<>();

View File

@@ -188,36 +188,31 @@ public class QueryStructReq extends SemanticQueryReq {
List<Aggregator> aggregators = queryStructReq.getAggregators(); List<Aggregator> aggregators = queryStructReq.getAggregators();
if (!CollectionUtils.isEmpty(aggregators)) { if (!CollectionUtils.isEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) { for (Aggregator aggregator : aggregators) {
selectItems.add(buildAggregatorSelectItem(aggregator, queryStructReq)); selectItems.add(buildAggregatorSelectItem(aggregator));
} }
} }
return selectItems; return selectItems;
} }
private SelectItem buildAggregatorSelectItem(Aggregator aggregator, private SelectItem buildAggregatorSelectItem(Aggregator aggregator) {
QueryStructReq queryStructReq) {
String columnName = aggregator.getColumn(); String columnName = aggregator.getColumn();
if (queryStructReq.getQueryType().isNativeAggQuery()) { Function function = new Function();
return new SelectItem(new Column(columnName)); AggOperatorEnum func = aggregator.getFunc();
} else { if (AggOperatorEnum.UNKNOWN.equals(func)) {
Function function = new Function(); func = AggOperatorEnum.SUM;
AggOperatorEnum func = aggregator.getFunc();
if (AggOperatorEnum.UNKNOWN.equals(func)) {
func = AggOperatorEnum.SUM;
}
function.setName(func.getOperator());
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
function.setName("count");
function.setDistinct(true);
}
function.setParameters(new ExpressionList(new Column(columnName)));
SelectItem selectExpressionItem = new SelectItem(function);
String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias()
: columnName;
selectExpressionItem.setAlias(new Alias(alias));
return selectExpressionItem;
} }
function.setName(func.getOperator());
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
function.setName("count");
function.setDistinct(true);
}
function.setParameters(new ExpressionList(new Column(columnName)));
SelectItem selectExpressionItem = new SelectItem(function);
String alias =
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
selectExpressionItem.setAlias(new Alias(alias));
return selectExpressionItem;
} }
private List<OrderByElement> buildOrderByElements(QueryStructReq queryStructReq) { private List<OrderByElement> buildOrderByElements(QueryStructReq queryStructReq) {
@@ -241,7 +236,7 @@ public class QueryStructReq extends SemanticQueryReq {
private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) { private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) {
List<String> groups = queryStructReq.getGroups(); List<String> groups = queryStructReq.getGroups();
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getQueryType().isNativeAggQuery()) { if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty()) {
GroupByElement groupByElement = new GroupByElement(); GroupByElement groupByElement = new GroupByElement();
for (String group : groups) { for (String group : groups) {
groupByElement.addGroupByExpression(new Column(group)); groupByElement.addGroupByExpression(new Column(group));

View File

@@ -41,6 +41,10 @@ public class ChatQueryContext implements Serializable {
} }
} }
public DataSetSchema getDataSetSchema(Long dataSetId) {
return semanticSchema.getDataSetSchema(dataSetId);
}
public List<SemanticQuery> getCandidateQueries() { public List<SemanticQuery> getCandidateQueries() {
candidateQueries = candidateQueries.stream() candidateQueries = candidateQueries.stream()
.sorted(Comparator.comparing( .sorted(Comparator.comparing(

View File

@@ -1,18 +1,10 @@
package com.tencent.supersonic.headless.chat.parser; package com.tencent.supersonic.headless.chat.parser;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */ /** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
@Slf4j @Slf4j
@@ -20,34 +12,17 @@ public class QueryTypeParser implements SemanticParser {
@Override @Override
public void parse(ChatQueryContext chatQueryContext) { public void parse(ChatQueryContext chatQueryContext) {
chatQueryContext.getCandidateQueries().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo();
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
QueryType queryType = QueryType.DETAIL;
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries(); if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
User user = chatQueryContext.getRequest().getUser(); queryType = QueryType.AGGREGATE;
}
for (SemanticQuery semanticQuery : candidateQueries) { parseInfo.setQueryType(queryType);
// 1.init S2SQL });
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
semanticQuery.initS2Sql(dataSetSchema, user);
// 2.set queryType
QueryType queryType = getQueryType(semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);
}
}
private QueryType getQueryType(SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
return QueryType.DETAIL;
}
if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) {
return QueryType.AGGREGATE;
}
return QueryType.DETAIL;
} }
} }

View File

@@ -49,6 +49,7 @@ public class LLMResponseService {
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery); queryCtx.getCandidateQueries().add(semanticQuery);
} }

View File

@@ -34,15 +34,13 @@ public class RuleSqlParser implements SemanticParser {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId); List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries = List<RuleSemanticQuery> queries =
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext); RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) { candidateQueries.addAll(queries);
query.fillParseInfo(chatQueryContext);
chatQueryContext.getCandidateQueries().add(query);
}
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
chatQueryContext.getCandidateQueries().clear();
} }
chatQueryContext.setCandidateQueries(candidateQueries); chatQueryContext.setCandidateQueries(candidateQueries);
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext)); auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
candidateQueries.forEach(query -> query.buildS2Sql(
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
} }
} }

View File

@@ -1,87 +1,24 @@
package com.tencent.supersonic.headless.chat.query; package com.tencent.supersonic.headless.chat.query;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import lombok.Data;
import lombok.ToString; import lombok.ToString;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.io.Serializable; import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@ToString @ToString
@Data
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
protected SemanticParseInfo parseInfo = new SemanticParseInfo(); protected SemanticParseInfo parseInfo = new SemanticParseInfo();
@Override
public SemanticParseInfo getParseInfo() {
return parseInfo;
}
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}
@Override @Override
public SemanticQueryReq buildSemanticQueryReq() { public SemanticQueryReq buildSemanticQueryReq() {
return QueryReqBuilder.buildS2SQLReq(parseInfo.getSqlInfo(), parseInfo.getDataSetId()); return QueryReqBuilder.buildS2SQLReq(parseInfo.getSqlInfo(), parseInfo.getDataSetId());
} }
protected void initS2SqlByStruct(DataSetSchema dataSetSchema) {
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
}
} }

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.query; package com.tencent.supersonic.headless.chat.query;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
@@ -13,7 +12,7 @@ public interface SemanticQuery {
SemanticQueryReq buildSemanticQueryReq() throws SqlParseException; SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;
void initS2Sql(DataSetSchema dataSetSchema, User user); void buildS2Sql(DataSetSchema dataSetSchema);
SemanticParseInfo getParseInfo(); SemanticParseInfo getParseInfo();

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql; package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -24,7 +23,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
} }
@Override @Override
public void initS2Sql(DataSetSchema dataSetSchema, User user) { public void buildS2Sql(DataSetSchema dataSetSchema) {
SqlInfo sqlInfo = parseInfo.getSqlInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL()); sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
} }

View File

@@ -1,7 +1,10 @@
package com.tencent.supersonic.headless.chat.query.rule; package com.tencent.supersonic.headless.chat.query.rule;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
@@ -10,6 +13,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery; import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
@@ -17,6 +22,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import lombok.ToString; import lombok.ToString;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.ArrayList;
@@ -26,7 +32,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM; import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM;
@@ -50,14 +55,24 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
} }
@Override @Override
public void initS2Sql(DataSetSchema dataSetSchema, User user) { public void buildS2Sql(DataSetSchema dataSetSchema) {
initS2SqlByStruct(dataSetSchema); QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
} }
public void fillParseInfo(ChatQueryContext chatQueryContext) { protected QueryStructReq convertQueryStruct() {
parseInfo.setQueryMode(getQueryMode()); return QueryReqBuilder.buildStructReq(parseInfo);
}
protected void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
parseInfo.setQueryMode(getQueryMode());
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
fillSchemaElement(parseInfo, semanticSchema); fillSchemaElement(parseInfo, semanticSchema);
fillScore(parseInfo); fillScore(parseInfo);
fillDateConfByInherited(parseInfo, chatQueryContext); fillDateConfByInherited(parseInfo, chatQueryContext);
@@ -110,12 +125,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
} }
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) { private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds =
parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>(); Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) { for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
@@ -200,14 +210,15 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public static List<RuleSemanticQuery> resolve(Long dataSetId, public static List<RuleSemanticQuery> resolve(Long dataSetId,
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) { List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>(); List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) { for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches = List<SchemaElementMatch> matches =
semanticQuery.match(candidateElementMatches, chatQueryContext); semanticQuery.match(candidateElementMatches, chatQueryContext);
if (!matches.isEmpty()) { if (!matches.isEmpty()) {
RuleSemanticQuery query = RuleSemanticQuery query =
QueryManager.createRuleQuery(semanticQuery.getQueryMode()); QueryManager.createRuleQuery(semanticQuery.getQueryMode());
query.getParseInfo().getElementMatches().addAll(matches); query.getParseInfo().getElementMatches().addAll(matches);
query.fillParseInfo(chatQueryContext, dataSetId);
matchedQueries.add(query); matchedQueries.add(query);
} }
} }
@@ -217,4 +228,39 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
protected QueryMultiStructReq convertQueryMultiStruct() { protected QueryMultiStructReq convertQueryMultiStruct() {
return QueryReqBuilder.buildMultiStructReq(parseInfo); return QueryReqBuilder.buildMultiStructReq(parseInfo);
} }
protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
}
} }

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.query.rule.detail; package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
@@ -29,12 +28,10 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
} }
@Override @Override
public void fillParseInfo(ChatQueryContext chatQueryContext) { public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext); super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setQueryType(QueryType.DETAIL);
parseInfo.setLimit(parseInfo.getDetailLimit()); parseInfo.setLimit(parseInfo.getDetailLimit());
if (!needFillDateConf(chatQueryContext)) { if (!needFillDateConf(chatQueryContext)) {
return; return;
} }

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.chat.query.rule.detail; package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@@ -31,16 +31,14 @@ public class DetailValueQuery extends DetailSemanticQuery {
} }
@Override @Override
public void fillParseInfo(ChatQueryContext chatQueryContext) { public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext); super.fillParseInfo(chatQueryContext, dataSetId);
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); DataSetSchema dataSetSchema = chatQueryContext.getDataSetSchema(dataSetId);
parseInfo.getDimensions().addAll(semanticSchema.getDimensions()); parseInfo.getDimensions().addAll(dataSetSchema.getDimensions());
parseInfo.getDimensions().forEach(d -> { parseInfo.getDimensions().forEach(
parseInfo.getElementMatches() d -> parseInfo.getElementMatches().add(SchemaElementMatch.builder().element(d)
.add(SchemaElementMatch.builder().element(d).word(d.getName()).similarity(0) .word(d.getName()).similarity(0).detectWord(d.getName()).build()));
.isInherited(false).detectWord(d.getName()).build());
});
} }

View File

@@ -32,8 +32,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
} }
@Override @Override
public void fillParseInfo(ChatQueryContext chatQueryContext) { public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext); super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setLimit(parseInfo.getMetricLimit()); parseInfo.setLimit(parseInfo.getMetricLimit());
fillDateInfo(chatQueryContext); fillDateInfo(chatQueryContext);
} }

View File

@@ -48,8 +48,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
} }
@Override @Override
public void fillParseInfo(ChatQueryContext chatQueryContext) { public void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
super.fillParseInfo(chatQueryContext); super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setScore(parseInfo.getScore() + 2.0); parseInfo.setScore(parseInfo.getScore() + 2.0);
parseInfo.setAggType(AggregateTypeEnum.SUM); parseInfo.setAggType(AggregateTypeEnum.SUM);

View File

@@ -17,14 +17,12 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@@ -51,12 +49,6 @@ public class QueryReqBuilder {
chatFilter.getOperator(), chatFilter.getValue())) chatFilter.getOperator(), chatFilter.getValue()))
.collect(Collectors.toList()); .collect(Collectors.toList());
queryStructReq.setMetricFilters(metricFilters); queryStructReq.setMetricFilters(metricFilters);
addDateDimension(parseInfo);
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
}
queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName) queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
.collect(Collectors.toList())); .collect(Collectors.toList()));
queryStructReq.setLimit(parseInfo.getLimit()); queryStructReq.setLimit(parseInfo.getLimit());
@@ -155,51 +147,6 @@ public class QueryReqBuilder {
return aggregateType.name(); return aggregateType.name();
} }
private static void addDateDimension(SemanticParseInfo parseInfo) {
if (parseInfo == null || parseInfo.getDateInfo() == null) {
return;
}
if (shouldSkipAddingDateDimension(parseInfo)) {
return;
}
String dateField = getDateField(parseInfo.getDateInfo());
if (isDateFieldAlreadyPresent(parseInfo, dateField)) {
return;
}
SchemaElement dimension = new SchemaElement();
dimension.setBizName(dateField);
if (QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
addDimension(parseInfo, dimension);
}
}
private static boolean shouldSkipAddingDateDimension(SemanticParseInfo parseInfo) {
return parseInfo.getAggType() != null
&& (parseInfo.getAggType().equals(AggregateTypeEnum.MAX)
|| parseInfo.getAggType().equals(AggregateTypeEnum.MIN))
&& !CollectionUtils.isEmpty(parseInfo.getDimensions());
}
private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo,
String dateField) {
return parseInfo.getDimensions().stream()
.anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField));
}
private static void addDimension(SemanticParseInfo parseInfo, SchemaElement dimension) {
List<String> timeDimensions = Arrays.asList(TimeDimensionEnum.DAY.getName(),
TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.MONTH.getName());
Set<SchemaElement> dimensions = parseInfo.getDimensions().stream()
.filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase()))
.collect(Collectors.toSet());
dimensions.add(dimension);
parseInfo.setDimensions(dimensions);
}
public static Set<Order> getOrder(Set<Order> existingOrders, AggregateTypeEnum aggregator, public static Set<Order> getOrder(Set<Order> existingOrders, AggregateTypeEnum aggregator,
SchemaElement metric) { SchemaElement metric) {
if (existingOrders != null && !existingOrders.isEmpty()) { if (existingOrders != null && !existingOrders.isEmpty()) {

View File

@@ -18,7 +18,13 @@ import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuer
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.*; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
@@ -40,7 +46,15 @@ import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
import com.tencent.supersonic.headless.api.pojo.request.*; import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.request.DictItemReq;
import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.TagObjectReq;
import com.tencent.supersonic.headless.api.pojo.request.TermReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
@@ -48,12 +62,15 @@ import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.TagObjectResp; import com.tencent.supersonic.headless.api.pojo.response.TagObjectResp;
import io.swagger.models.auth.In;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.*; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component @Component
@Slf4j @Slf4j
@@ -131,7 +148,6 @@ public class S2VisitsDemo extends S2BaseDemo {
submitText(chatId.intValue(), agentId, "超音数 访问次数"); submitText(chatId.intValue(), agentId, "超音数 访问次数");
submitText(chatId.intValue(), agentId, "按部门统计近7天访问次数"); submitText(chatId.intValue(), agentId, "按部门统计近7天访问次数");
submitText(chatId.intValue(), agentId, "alice 停留时长"); submitText(chatId.intValue(), agentId, "alice 停留时长");
submitText(chatId.intValue(), agentId, "访问次数最高的部门");
} }
private void submitText(int chatId, int agentId, String queryText) { private void submitText(int chatId, int agentId, String queryText) {