(improvement)(Headless) Refactor ChatLayerService and SemanticLayerService (#1404)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-07-14 11:23:47 +08:00
committed by GitHub
parent baff30550e
commit 407c8d4702
16 changed files with 496 additions and 514 deletions

View File

@@ -13,29 +13,70 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@@ -47,13 +88,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private ChatLayerService chatLayerService;
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private RetrieveService retrieveService;
@Autowired
private AgentService agentService;
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private ChatContextService chatContextService;
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
@@ -160,23 +199,328 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return executeContext;
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
Integer parseId = chatQueryDataReq.getParseId();
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
chatQueryDataReq.getQueryId(), parseId);
QueryDataReq queryData = new QueryDataReq();
BeanMapper.mapper(chatQueryDataReq, queryData);
queryData.setParseInfo(parseInfo);
return chatLayerService.executeDirectQuery(queryData, user);
parseInfo = mergeSemanticParseInfo(parseInfo, chatQueryDataReq);
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
//replace metrics
log.info("llm begin replace metrics!");
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql
semanticQuery.initS2Sql(dataSetSchema, user);
}
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);
queryResult.setChatContext(semanticQuery.getParseInfo());
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
queryResult.setEntityInfo(entityInfo);
return queryResult;
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
if (CollectionUtils.isEmpty(oriFields)) {
return false;
}
if (CollectionUtils.isEmpty(metrics)) {
return false;
}
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
return !oriFields.containsAll(metricNames);
}
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
if (!CollectionUtils.isEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult();
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
}
String sql = queryResp == null ? null : queryResp.getSql();
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
: queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(sql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private void updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) {
return;
}
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
}
// startDate equals to endDate
for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
// first remove,then add
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
break;
}
}
for (FieldExpression fieldExpression : fieldExpressionList) {
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
&& FilterOperatorEnum.LIKE.getValue().toLowerCase().equals(
fieldExpression.getOperator().toLowerCase())) {
Map<String, String> replaceMap = new HashMap<>();
String preValue = fieldExpression.getFieldValue().toString();
String curValue = queryFilter.getValue().toString();
if (preValue.startsWith("%")) {
curValue = "%" + curValue;
}
if (preValue.endsWith("%")) {
curValue = curValue + "%";
}
replaceMap.put(preValue, curValue);
filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap);
break;
}
}
}
parseInfo.setDateInfo(queryData.getDateInfo());
}
private <T extends ComparisonOperator> void addTimeFilters(String date,
T comparisonExpression,
List<Expression> addConditions) {
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(stringValue);
addConditions.add(comparisonExpression);
}
private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (org.apache.commons.collections.CollectionUtils.isEmpty(metricFilters)) {
return;
}
for (QueryFilter dslQueryFilter : metricFilters) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (fieldExpression.getFieldName() != null
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
removeFieldNames.add(dslQueryFilter.getName());
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
EqualsTo equalsTo = new EqualsTo();
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
break;
}
}
}
}
// add in condition to sql where condition
private void addWhereInFilters(QueryFilter dslQueryFilter,
InExpression inExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName());
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
List<String> valueList = JsonUtil.toList(
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
if (org.apache.commons.collections.CollectionUtils.isEmpty(valueList)) {
return;
}
valueList.stream().forEach(o -> {
StringValue stringValue = new StringValue(o);
parenthesedExpressionList.add(stringValue);
});
inExpression.setLeftExpression(column);
inExpression.setRightExpression(parenthesedExpressionList);
addConditions.add(inExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
// add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
}
if (Objects.isNull(dslQueryFilter.getValue())) {
return;
}
Column column = new Column(columnName);
comparisonExpression.setLeftExpression(column);
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
comparisonExpression.setRightExpression(longValue);
} else {
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
comparisonExpression.setRightExpression(stringValue);
}
addConditions.add(comparisonExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
private SemanticParseInfo mergeSemanticParseInfo(SemanticParseInfo parseInfo,
ChatQueryDataReq queryData) {
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
}
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
}
if (!CollectionUtils.isEmpty(queryData.getMetrics())) {
parseInfo.setMetrics(queryData.getMetrics());
}
if (!CollectionUtils.isEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
parseInfo.setMetricFilters(queryData.getMetricFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
return parseInfo;
}
private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) {
if (Objects.isNull(queryFilter.getValue())) {
filters.remove(queryFilter);
}
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
filters.remove(queryFilter);
}
}
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) {
Integer agentId = dimensionValueReq.getAgentId();
Agent agent = agentService.getAgent(agentId);
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
return chatLayerService.queryDimensionValue(dimensionValueReq, user);
return semanticLayerService.queryDimensionValue(dimensionValueReq, user);
}
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {

View File

@@ -1,14 +1,16 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
@Data
public class DataSetSchema {
@@ -57,6 +59,14 @@ public class DataSetSchema {
}
}
public Map<String, String> getBizNameToName() {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions());
allElements.addAll(getMetrics());
return allElements.stream()
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
if (queryConfig == null) {
return null;

View File

@@ -150,14 +150,6 @@ public class SemanticSchema implements Serializable {
return dataSets;
}
public Map<String, String> getBizNameToName(Long dataSetId) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(dataSetId));
allElements.addAll(getMetrics(dataSetId));
return allElements.stream()
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
return new HashMap<>();

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.common.pojo.DateConf;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import javax.validation.constraints.NotNull;
import java.util.Set;
@@ -21,4 +23,11 @@ public class DimensionValueReq {
private Set<Long> dataSetIds;
private DateConf dateInfo = new DateConf();
private String dimensionBizName;
public String getBizName() {
return StringUtils.isBlank(bizName) ? dimensionBizName : bizName;
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.common.pojo.DateConf;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class QueryDimValueReq {
private Long modelId;
private String dimensionBizName;
private String value;
private DateConf dateInfo = new DateConf();
}

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -35,7 +36,10 @@ public class QueryTypeParser implements SemanticParser {
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL
semanticQuery.initS2Sql(chatQueryContext.getSemanticSchema(), user);
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema()
.getDataSetSchemaMap().get(dataSetId);
semanticQuery.initS2Sql(dataSetSchema, user);
// 2.set queryType
QueryType queryType = getQueryType(chatQueryContext, semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
@@ -43,8 +43,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
return QueryReqBuilder.buildStructReq(parseInfo);
}
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId());
protected void convertBizNameToName(DataSetSchema dataSetSchema, QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
@@ -74,14 +74,14 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
}
}
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
protected void initS2SqlByStruct(DataSetSchema dataSetSchema) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE));
if (!s2sqlEnable) {
return;
}
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq);
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.headless.chat.query;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import org.apache.calcite.sql.parser.SqlParseException;
@@ -15,7 +15,7 @@ public interface SemanticQuery {
SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;
void initS2Sql(SemanticSchema semanticSchema, User user);
void initS2Sql(DataSetSchema dataSetSchema, User user);
SemanticParseInfo getParseInfo();

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@@ -31,7 +31,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
}
@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.rule;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -13,9 +14,13 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -24,9 +29,6 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@Slf4j
@ToString
@@ -44,8 +46,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}
@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
initS2SqlByStruct(semanticSchema);
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
initS2SqlByStruct(dataSetSchema);
}
public void fillParseInfo(ChatQueryContext chatQueryContext) {

View File

@@ -2,15 +2,12 @@ package com.tencent.supersonic.headless.server.facade.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
/***dd
* SemanticLayerService for query and search
@@ -21,10 +18,6 @@ public interface ChatLayerService {
ParseResp performParsing(QueryNLReq queryNLReq);
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
MapInfoResp map(QueryMapReq queryMapReq);
void correct(QuerySqlReq querySqlReq, User user);

View File

@@ -4,11 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import java.util.List;
@@ -23,7 +23,7 @@ public interface SemanticLayerService {
SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
SemanticQueryResp queryDimensionValue(DimensionValueReq dimensionValueReq, User user);
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);

View File

@@ -2,20 +2,8 @@ package com.tencent.supersonic.headless.server.facade.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -25,60 +13,27 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -86,7 +41,6 @@ import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -97,13 +51,9 @@ import java.util.stream.Collectors;
@Service
@Slf4j
public class S2ChatLayerService implements ChatLayerService {
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private SchemaService schemaService;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@Autowired
private DataSetService dataSetService;
@Autowired
private ChatWorkflowEngine chatWorkflowEngine;
@@ -166,386 +116,6 @@ public class S2ChatLayerService implements ChatLayerService {
return queryCtx;
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception {
SemanticParseInfo parseInfo = getSemanticParseInfo(queryData);
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, queryData.getMetrics())) {
//replace metrics
log.info("llm begin replace metrics!");
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql
semanticQuery.initS2Sql(semanticSchema, user);
QueryNLReq queryNLReq = new QueryNLReq();
queryNLReq.setQueryFilters(new QueryFilters());
queryNLReq.setUser(user);
}
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);
queryResult.setChatContext(semanticQuery.getParseInfo());
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
queryResult.setEntityInfo(entityInfo);
return queryResult;
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult();
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
}
String sql = queryResp == null ? null : queryResp.getSql();
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
: queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(sql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
if (CollectionUtils.isEmpty(oriFields)) {
return false;
}
if (CollectionUtils.isEmpty(metrics)) {
return false;
}
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
return !oriFields.containsAll(metricNames);
}
private String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) {
return;
}
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
}
// startDate equals to endDate
for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
// first remove,then add
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
break;
}
}
for (FieldExpression fieldExpression : fieldExpressionList) {
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
&& FilterOperatorEnum.LIKE.getValue().toLowerCase().equals(
fieldExpression.getOperator().toLowerCase())) {
Map<String, String> replaceMap = new HashMap<>();
String preValue = fieldExpression.getFieldValue().toString();
String curValue = queryFilter.getValue().toString();
if (preValue.startsWith("%")) {
curValue = "%" + curValue;
}
if (preValue.endsWith("%")) {
curValue = curValue + "%";
}
replaceMap.put(preValue, curValue);
filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap);
break;
}
}
}
parseInfo.setDateInfo(queryData.getDateInfo());
}
private <T extends ComparisonOperator> void addTimeFilters(String date,
T comparisonExpression,
List<Expression> addConditions) {
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(stringValue);
addConditions.add(comparisonExpression);
}
private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (CollectionUtils.isEmpty(metricFilters)) {
return;
}
for (QueryFilter dslQueryFilter : metricFilters) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (fieldExpression.getFieldName() != null
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
removeFieldNames.add(dslQueryFilter.getName());
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
EqualsTo equalsTo = new EqualsTo();
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
break;
}
}
}
}
// add in condition to sql where condition
private void addWhereInFilters(QueryFilter dslQueryFilter,
InExpression inExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName());
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
List<String> valueList = JsonUtil.toList(
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
if (CollectionUtils.isEmpty(valueList)) {
return;
}
valueList.stream().forEach(o -> {
StringValue stringValue = new StringValue(o);
parenthesedExpressionList.add(stringValue);
});
inExpression.setLeftExpression(column);
inExpression.setRightExpression(parenthesedExpressionList);
addConditions.add(inExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
// add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
}
if (Objects.isNull(dslQueryFilter.getValue())) {
return;
}
Column column = new Column(columnName);
comparisonExpression.setLeftExpression(column);
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
comparisonExpression.setRightExpression(longValue);
} else {
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
comparisonExpression.setRightExpression(stringValue);
}
addConditions.add(comparisonExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData) {
SemanticParseInfo parseInfo = queryData.getParseInfo();
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
}
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
}
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
parseInfo.setMetrics(queryData.getMetrics());
}
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) {
parseInfo.setMetricFilters(queryData.getMetricFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
return parseInfo;
}
private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) {
if (Objects.isNull(queryFilter.getValue())) {
filters.remove(queryFilter);
}
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
filters.remove(queryFilter);
}
}
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
SemanticQueryResp semanticQueryResp = new SemanticQueryResp();
DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID());
Set<Long> dataSetIds = dimensionValueReq.getDataSetIds();
dimensionValueReq.setModelId(dimensionResp.getModelId());
List<String> dimensionValues = getDimensionValues(dimensionValueReq, dataSetIds);
// if the search results is null,search dimensionValue from database
if (CollectionUtils.isEmpty(dimensionValues)) {
semanticQueryResp = queryDatabase(dimensionValueReq, user);
return semanticQueryResp;
}
List<QueryColumn> columns = new ArrayList<>();
QueryColumn queryColumn = new QueryColumn();
queryColumn.setNameEn(dimensionValueReq.getBizName());
queryColumn.setShowType("CATEGORY");
queryColumn.setAuthorized(true);
queryColumn.setType("CHAR");
columns.add(queryColumn);
List<Map<String, Object>> resultList = new ArrayList<>();
dimensionValues.stream().forEach(o -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), o);
resultList.add(map);
});
semanticQueryResp.setColumns(columns);
semanticQueryResp.setResultList(resultList);
return semanticQueryResp;
}
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
//if value is null ,then search from NATURE_TO_VALUES
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
return SearchService.getDimensionValue(dimensionValueReq);
}
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(),
2000, modelIdToDataSetIds, dataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
})
.map(mapResult -> mapResult.getName())
.collect(Collectors.toList());
}
private SemanticQueryResp queryDatabase(DimensionValueReq dimensionValueReq, User user) {
QueryDimValueReq queryDimValueReq = new QueryDimValueReq();
queryDimValueReq.setValue(dimensionValueReq.getValue());
queryDimValueReq.setModelId(dimensionValueReq.getModelId());
queryDimValueReq.setDimensionBizName(dimensionValueReq.getBizName());
return semanticLayerService.queryDimValue(queryDimValueReq, user);
}
public void correct(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());

View File

@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
@@ -19,7 +20,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
@@ -27,11 +29,17 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.core.cache.QueryCache;
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
@@ -57,6 +65,7 @@ import org.springframework.stereotype.Service;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@@ -77,6 +86,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
private final SchemaService schemaService;
private final SemanticTranslator semanticTranslator;
private final MetricDrillDownChecker metricDrillDownChecker;
private final KnowledgeBaseService knowledgeBaseService;
private QueryCache queryCache = ComponentFactory.getQueryCache();
private List<QueryExecutor> queryExecutors = ComponentFactory.getQueryExecutors();
@@ -88,7 +98,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
DataSetService dataSetService,
SchemaService schemaService,
SemanticTranslator semanticTranslator,
MetricDrillDownChecker metricDrillDownChecker) {
MetricDrillDownChecker metricDrillDownChecker,
KnowledgeBaseService knowledgeBaseService) {
this.statUtils = statUtils;
this.queryUtils = queryUtils;
this.queryReqConverter = queryReqConverter;
@@ -97,6 +108,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
this.schemaService = schemaService;
this.semanticTranslator = semanticTranslator;
this.metricDrillDownChecker = metricDrillDownChecker;
this.knowledgeBaseService = knowledgeBaseService;
}
public DataSetSchema getDataSetSchema(Long id) {
@@ -175,12 +187,74 @@ public class S2SemanticLayerService implements SemanticLayerService {
}
@Override
@SneakyThrows
public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
QuerySqlReq querySqlReq = buildQuerySqlReq(queryDimValueReq);
public SemanticQueryResp queryDimensionValue(DimensionValueReq dimensionValueReq, User user) {
SemanticQueryResp semanticQueryResp = new SemanticQueryResp();
DimensionResp dimensionResp = getDimension(dimensionValueReq);
Set<Long> dataSetIds = dimensionValueReq.getDataSetIds();
dimensionValueReq.setModelId(dimensionResp.getModelId());
List<String> dimensionValues = getDimensionValuesFromDict(dimensionValueReq, dataSetIds);
// if the search results is null,search dimensionValue from database
if (CollectionUtils.isEmpty(dimensionValues)) {
semanticQueryResp = getDimensionValuesFromDb(dimensionValueReq, user);
return semanticQueryResp;
}
List<QueryColumn> columns = new ArrayList<>();
QueryColumn queryColumn = new QueryColumn();
queryColumn.setNameEn(dimensionValueReq.getBizName());
queryColumn.setShowType(SemanticType.CATEGORY.name());
queryColumn.setAuthorized(true);
queryColumn.setType("CHAR");
columns.add(queryColumn);
List<Map<String, Object>> resultList = new ArrayList<>();
dimensionValues.stream().forEach(o -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), o);
resultList.add(map);
});
semanticQueryResp.setColumns(columns);
semanticQueryResp.setResultList(resultList);
return semanticQueryResp;
}
private List<String> getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
//if value is null ,then search from NATURE_TO_VALUES
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
return SearchService.getDimensionValue(dimensionValueReq);
}
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(),
2000, modelIdToDataSetIds, dataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
})
.map(MapResult::getName)
.collect(Collectors.toList());
}
private SemanticQueryResp getDimensionValuesFromDb(DimensionValueReq dimensionValueReq, User user) {
QuerySqlReq querySqlReq = buildQuerySqlReq(dimensionValueReq);
return queryByReq(querySqlReq, user);
}
private DimensionResp getDimension(DimensionValueReq dimensionValueReq) {
DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID());
if (dimensionResp == null) {
return schemaService.getDimension(dimensionValueReq.getBizName(),
dimensionValueReq.getModelId());
}
return dimensionResp;
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
if (parseInfo != null && parseInfo.getDataSetId() != null && parseInfo.getDataSetId() > 0) {
EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
@@ -291,10 +365,10 @@ public class S2SemanticLayerService implements SemanticLayerService {
return schemaFilterReq;
}
private QuerySqlReq buildQuerySqlReq(QueryDimValueReq queryDimValueReq) {
private QuerySqlReq buildQuerySqlReq(DimensionValueReq queryDimValueReq) {
QuerySqlReq querySqlReq = new QuerySqlReq();
List<ModelResp> modelResps = schemaService.getModelList(Lists.newArrayList(queryDimValueReq.getModelId()));
DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getDimensionBizName(),
DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getBizName(),
queryDimValueReq.getModelId());
ModelResp modelResp = modelResps.get(0);
String sql = String.format("select distinct %s from %s where 1=1",
@@ -306,7 +380,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
queryDimValueReq.getDateInfo().getEndDate());
}
if (StringUtils.isNotBlank(queryDimValueReq.getValue())) {
sql += " AND " + queryDimValueReq.getDimensionBizName() + " LIKE '%" + queryDimValueReq.getValue() + "%'";
sql += " AND " + queryDimValueReq.getBizName() + " LIKE '%" + queryDimValueReq.getValue() + "%'";
}
querySqlReq.setModelIds(Sets.newHashSet(queryDimValueReq.getModelId()));
querySqlReq.setSql(sql);

View File

@@ -7,15 +7,15 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.web.service.DimensionService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
@@ -121,11 +121,11 @@ public class DimensionController {
}
@PostMapping("/queryDimValue")
public SemanticQueryResp queryDimValue(@RequestBody QueryDimValueReq queryDimValueReq,
public SemanticQueryResp queryDimValue(@RequestBody DimensionValueReq dimensionValueReq,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return queryService.queryDimValue(queryDimValueReq, user);
return queryService.queryDimensionValue(dimensionValueReq, user);
}
@DeleteMapping("deleteDimension/{id}")

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
@@ -10,11 +10,11 @@ public class QueryDimensionTest extends BaseTest {
@Test
public void testQueryDimValue() {
QueryDimValueReq queryDimValueReq = new QueryDimValueReq();
DimensionValueReq queryDimValueReq = new DimensionValueReq();
queryDimValueReq.setModelId(1L);
queryDimValueReq.setDimensionBizName("department");
queryDimValueReq.setBizName("department");
SemanticQueryResp queryResp = semanticLayerService.queryDimValue(queryDimValueReq, User.getFakeUser());
SemanticQueryResp queryResp = semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(4, queryResp.getResultList().size());
}