mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(Headless) Refactor ChatLayerService and SemanticLayerService (#1404)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -13,29 +13,70 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@@ -47,13 +88,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Autowired
|
||||
private ChatLayerService chatLayerService;
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private RetrieveService retrieveService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
|
||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
@@ -160,23 +199,328 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return executeContext;
|
||||
}
|
||||
|
||||
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
|
||||
//"style='流行'"->"style in ['流行','爱国']"
|
||||
@Override
|
||||
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
|
||||
Integer parseId = chatQueryDataReq.getParseId();
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatQueryDataReq.getQueryId(), parseId);
|
||||
QueryDataReq queryData = new QueryDataReq();
|
||||
BeanMapper.mapper(chatQueryDataReq, queryData);
|
||||
queryData.setParseInfo(parseInfo);
|
||||
return chatLayerService.executeDirectQuery(queryData, user);
|
||||
parseInfo = mergeSemanticParseInfo(parseInfo, chatQueryDataReq);
|
||||
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
|
||||
List<String> fields = new ArrayList<>();
|
||||
if (Objects.nonNull(parseInfo.getSqlInfo())
|
||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
fields = SqlSelectHelper.getAllFields(correctorSql);
|
||||
}
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
|
||||
&& checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
|
||||
//replace metrics
|
||||
log.info("llm begin replace metrics!");
|
||||
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
|
||||
replaceMetrics(parseInfo, metricToReplace);
|
||||
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||
log.info("llm begin revise filters!");
|
||||
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
} else {
|
||||
log.info("rule begin replace metrics and revise filters!");
|
||||
//remove unvalid filters
|
||||
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
|
||||
validFilter(semanticQuery.getParseInfo().getMetricFilters());
|
||||
//init s2sql
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
}
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);
|
||||
queryResult.setChatContext(semanticQuery.getParseInfo());
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
|
||||
queryResult.setEntityInfo(entityInfo);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||
if (CollectionUtils.isEmpty(oriFields)) {
|
||||
return false;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return false;
|
||||
}
|
||||
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||
return !oriFields.containsAll(metricNames);
|
||||
}
|
||||
|
||||
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) {
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<Expression> addWhereConditions = new ArrayList<>();
|
||||
List<Expression> addHavingConditions = new ArrayList<>();
|
||||
Set<String> removeWhereFieldNames = new HashSet<>();
|
||||
Set<String> removeHavingFieldNames = new HashSet<>();
|
||||
// replace where filter
|
||||
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
|
||||
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
|
||||
whereExpressionList, addWhereConditions, removeWhereFieldNames);
|
||||
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||
// replace having filter
|
||||
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
|
||||
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||
|
||||
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
|
||||
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
log.info("correctorSql after replacing:{}", correctorSql);
|
||||
return correctorSql;
|
||||
}
|
||||
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("before replaceMetrics:{}", correctorSql);
|
||||
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
if (!CollectionUtils.isEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
}
|
||||
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
|
||||
SemanticParseInfo parseInfo, User user) throws Exception {
|
||||
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
}
|
||||
|
||||
String sql = queryResp == null ? null : queryResp.getSql();
|
||||
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
|
||||
: queryResp.getResultList();
|
||||
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
|
||||
queryResult.setQuerySql(sql);
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(columns);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private void updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
Map<String, Map<String, String>> filedNameToValueMap,
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
List<Expression> addConditions,
|
||||
Set<String> removeFieldNames) {
|
||||
if (Objects.isNull(queryData.getDateInfo())) {
|
||||
return;
|
||||
}
|
||||
if (queryData.getDateInfo().getUnit() > 1) {
|
||||
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
|
||||
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
|
||||
}
|
||||
// startDate equals to endDate
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||
// first remove,then add
|
||||
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
|
||||
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
|
||||
&& FilterOperatorEnum.LIKE.getValue().toLowerCase().equals(
|
||||
fieldExpression.getOperator().toLowerCase())) {
|
||||
Map<String, String> replaceMap = new HashMap<>();
|
||||
String preValue = fieldExpression.getFieldValue().toString();
|
||||
String curValue = queryFilter.getValue().toString();
|
||||
if (preValue.startsWith("%")) {
|
||||
curValue = "%" + curValue;
|
||||
}
|
||||
if (preValue.endsWith("%")) {
|
||||
curValue = curValue + "%";
|
||||
}
|
||||
replaceMap.put(preValue, curValue);
|
||||
filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
|
||||
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
||||
T comparisonExpression,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
||||
StringValue stringValue = new StringValue(date);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
addConditions.add(comparisonExpression);
|
||||
}
|
||||
|
||||
private void updateFilters(List<FieldExpression> fieldExpressionList,
|
||||
Set<QueryFilter> metricFilters,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions,
|
||||
Set<String> removeFieldNames) {
|
||||
if (org.apache.commons.collections.CollectionUtils.isEmpty(metricFilters)) {
|
||||
return;
|
||||
}
|
||||
for (QueryFilter dslQueryFilter : metricFilters) {
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (fieldExpression.getFieldName() != null
|
||||
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||
removeFieldNames.add(dslQueryFilter.getName());
|
||||
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||
EqualsTo equalsTo = new EqualsTo();
|
||||
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
|
||||
GreaterThan greaterThan = new GreaterThan();
|
||||
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
|
||||
MinorThan minorThan = new MinorThan();
|
||||
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
|
||||
InExpression inExpression = new InExpression();
|
||||
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add in condition to sql where condition
|
||||
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
||||
InExpression inExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(dslQueryFilter.getName());
|
||||
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
|
||||
List<String> valueList = JsonUtil.toList(
|
||||
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
|
||||
if (org.apache.commons.collections.CollectionUtils.isEmpty(valueList)) {
|
||||
return;
|
||||
}
|
||||
valueList.stream().forEach(o -> {
|
||||
StringValue stringValue = new StringValue(o);
|
||||
parenthesedExpressionList.add(stringValue);
|
||||
});
|
||||
inExpression.setLeftExpression(column);
|
||||
inExpression.setRightExpression(parenthesedExpressionList);
|
||||
addConditions.add(inExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// add where filter
|
||||
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
|
||||
T comparisonExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
String columnName = dslQueryFilter.getName();
|
||||
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
||||
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
||||
}
|
||||
if (Objects.isNull(dslQueryFilter.getValue())) {
|
||||
return;
|
||||
}
|
||||
Column column = new Column(columnName);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
|
||||
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
|
||||
comparisonExpression.setRightExpression(longValue);
|
||||
} else {
|
||||
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
}
|
||||
addConditions.add(comparisonExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private SemanticParseInfo mergeSemanticParseInfo(SemanticParseInfo parseInfo,
|
||||
ChatQueryDataReq queryData) {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return parseInfo;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
|
||||
parseInfo.setDimensions(queryData.getDimensions());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getMetrics())) {
|
||||
parseInfo.setMetrics(queryData.getMetrics());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getDimensionFilters())) {
|
||||
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
|
||||
parseInfo.setMetricFilters(queryData.getMetricFilters());
|
||||
}
|
||||
if (Objects.nonNull(queryData.getDateInfo())) {
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
private void validFilter(Set<QueryFilter> filters) {
|
||||
for (QueryFilter queryFilter : filters) {
|
||||
if (Objects.isNull(queryFilter.getValue())) {
|
||||
filters.remove(queryFilter);
|
||||
}
|
||||
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
|
||||
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
|
||||
filters.remove(queryFilter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) {
|
||||
Integer agentId = dimensionValueReq.getAgentId();
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
||||
return chatLayerService.queryDimensionValue(dimensionValueReq, user);
|
||||
return semanticLayerService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Data
|
||||
public class DataSetSchema {
|
||||
@@ -57,6 +59,14 @@ public class DataSetSchema {
|
||||
}
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName() {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(getDimensions());
|
||||
allElements.addAll(getMetrics());
|
||||
return allElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
|
||||
@@ -150,14 +150,6 @@ public class SemanticSchema implements Serializable {
|
||||
return dataSets;
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName(Long dataSetId) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(getDimensions(dataSetId));
|
||||
allElements.addAll(getMetrics(dataSetId));
|
||||
return allElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
|
||||
return new HashMap<>();
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import javax.validation.constraints.NotNull;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -21,4 +23,11 @@ public class DimensionValueReq {
|
||||
|
||||
private Set<Long> dataSetIds;
|
||||
|
||||
private DateConf dateInfo = new DateConf();
|
||||
|
||||
private String dimensionBizName;
|
||||
|
||||
public String getBizName() {
|
||||
return StringUtils.isBlank(bizName) ? dimensionBizName : bizName;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class QueryDimValueReq {
|
||||
|
||||
private Long modelId;
|
||||
private String dimensionBizName;
|
||||
private String value;
|
||||
private DateConf dateInfo = new DateConf();
|
||||
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
@@ -35,7 +36,10 @@ public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(chatQueryContext.getSemanticSchema(), user);
|
||||
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema()
|
||||
.getDataSetSchemaMap().get(dataSetId);
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(chatQueryContext, semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
@@ -43,8 +43,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId());
|
||||
protected void convertBizNameToName(DataSetSchema dataSetSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
@@ -74,14 +74,14 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
|
||||
protected void initS2SqlByStruct(DataSetSchema dataSetSchema) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE));
|
||||
if (!s2sqlEnable) {
|
||||
return;
|
||||
}
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(semanticSchema, queryStructReq);
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
|
||||
@@ -15,7 +15,7 @@ public interface SemanticQuery {
|
||||
|
||||
SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;
|
||||
|
||||
void initS2Sql(SemanticSchema semanticSchema, User user);
|
||||
void initS2Sql(DataSetSchema dataSetSchema, User user);
|
||||
|
||||
SemanticParseInfo getParseInfo();
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -31,7 +31,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.rule;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
@@ -13,9 +14,13 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
@@ -24,9 +29,6 @@ import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@@ -44,8 +46,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
initS2SqlByStruct(semanticSchema);
|
||||
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
|
||||
initS2SqlByStruct(dataSetSchema);
|
||||
}
|
||||
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
|
||||
@@ -2,15 +2,12 @@ package com.tencent.supersonic.headless.server.facade.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
/***dd
|
||||
* SemanticLayerService for query and search
|
||||
@@ -21,10 +18,6 @@ public interface ChatLayerService {
|
||||
|
||||
ParseResp performParsing(QueryNLReq queryNLReq);
|
||||
|
||||
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
MapInfoResp map(QueryMapReq queryMapReq);
|
||||
|
||||
void correct(QuerySqlReq querySqlReq, User user);
|
||||
|
||||
@@ -4,11 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -23,7 +23,7 @@ public interface SemanticLayerService {
|
||||
|
||||
SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception;
|
||||
|
||||
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||
SemanticQueryResp queryDimensionValue(DimensionValueReq dimensionValueReq, User user);
|
||||
|
||||
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);
|
||||
|
||||
|
||||
@@ -2,20 +2,8 @@ package com.tencent.supersonic.headless.server.facade.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
@@ -25,60 +13,27 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
|
||||
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
|
||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.server.web.service.DataSetService;
|
||||
import com.tencent.supersonic.headless.server.web.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -86,7 +41,6 @@ import org.springframework.stereotype.Service;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -97,13 +51,9 @@ import java.util.stream.Collectors;
|
||||
@Service
|
||||
@Slf4j
|
||||
public class S2ChatLayerService implements ChatLayerService {
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired
|
||||
private DataSetService dataSetService;
|
||||
@Autowired
|
||||
private ChatWorkflowEngine chatWorkflowEngine;
|
||||
@@ -166,386 +116,6 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
return queryCtx;
|
||||
}
|
||||
|
||||
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
|
||||
//"style='流行'"->"style in ['流行','爱国']"
|
||||
@Override
|
||||
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception {
|
||||
SemanticParseInfo parseInfo = getSemanticParseInfo(queryData);
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
|
||||
List<String> fields = new ArrayList<>();
|
||||
if (Objects.nonNull(parseInfo.getSqlInfo())
|
||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
fields = SqlSelectHelper.getAllFields(correctorSql);
|
||||
}
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
|
||||
&& checkMetricReplace(fields, queryData.getMetrics())) {
|
||||
//replace metrics
|
||||
log.info("llm begin replace metrics!");
|
||||
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
|
||||
replaceMetrics(parseInfo, metricToReplace);
|
||||
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||
log.info("llm begin revise filters!");
|
||||
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
} else {
|
||||
log.info("rule begin replace metrics and revise filters!");
|
||||
//remove unvalid filters
|
||||
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
|
||||
validFilter(semanticQuery.getParseInfo().getMetricFilters());
|
||||
//init s2sql
|
||||
semanticQuery.initS2Sql(semanticSchema, user);
|
||||
QueryNLReq queryNLReq = new QueryNLReq();
|
||||
queryNLReq.setQueryFilters(new QueryFilters());
|
||||
queryNLReq.setUser(user);
|
||||
}
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);
|
||||
queryResult.setChatContext(semanticQuery.getParseInfo());
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
|
||||
queryResult.setEntityInfo(entityInfo);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
|
||||
SemanticParseInfo parseInfo, User user) throws Exception {
|
||||
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
}
|
||||
|
||||
String sql = queryResp == null ? null : queryResp.getSql();
|
||||
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
|
||||
: queryResp.getResultList();
|
||||
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
|
||||
queryResult.setQuerySql(sql);
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(columns);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||
if (CollectionUtils.isEmpty(oriFields)) {
|
||||
return false;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return false;
|
||||
}
|
||||
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||
return !oriFields.containsAll(metricNames);
|
||||
}
|
||||
|
||||
private String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<Expression> addWhereConditions = new ArrayList<>();
|
||||
List<Expression> addHavingConditions = new ArrayList<>();
|
||||
Set<String> removeWhereFieldNames = new HashSet<>();
|
||||
Set<String> removeHavingFieldNames = new HashSet<>();
|
||||
// replace where filter
|
||||
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
|
||||
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
|
||||
whereExpressionList, addWhereConditions, removeWhereFieldNames);
|
||||
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||
// replace having filter
|
||||
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
|
||||
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||
|
||||
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
|
||||
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
log.info("correctorSql after replacing:{}", correctorSql);
|
||||
return correctorSql;
|
||||
}
|
||||
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("before replaceMetrics:{}", correctorSql);
|
||||
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
}
|
||||
|
||||
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
Map<String, Map<String, String>> filedNameToValueMap,
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
List<Expression> addConditions,
|
||||
Set<String> removeFieldNames) {
|
||||
if (Objects.isNull(queryData.getDateInfo())) {
|
||||
return;
|
||||
}
|
||||
if (queryData.getDateInfo().getUnit() > 1) {
|
||||
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
|
||||
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
|
||||
}
|
||||
// startDate equals to endDate
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||
// first remove,then add
|
||||
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
|
||||
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
|
||||
&& FilterOperatorEnum.LIKE.getValue().toLowerCase().equals(
|
||||
fieldExpression.getOperator().toLowerCase())) {
|
||||
Map<String, String> replaceMap = new HashMap<>();
|
||||
String preValue = fieldExpression.getFieldValue().toString();
|
||||
String curValue = queryFilter.getValue().toString();
|
||||
if (preValue.startsWith("%")) {
|
||||
curValue = "%" + curValue;
|
||||
}
|
||||
if (preValue.endsWith("%")) {
|
||||
curValue = curValue + "%";
|
||||
}
|
||||
replaceMap.put(preValue, curValue);
|
||||
filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
|
||||
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
||||
T comparisonExpression,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
||||
StringValue stringValue = new StringValue(date);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
addConditions.add(comparisonExpression);
|
||||
}
|
||||
|
||||
private void updateFilters(List<FieldExpression> fieldExpressionList,
|
||||
Set<QueryFilter> metricFilters,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions,
|
||||
Set<String> removeFieldNames) {
|
||||
if (CollectionUtils.isEmpty(metricFilters)) {
|
||||
return;
|
||||
}
|
||||
for (QueryFilter dslQueryFilter : metricFilters) {
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (fieldExpression.getFieldName() != null
|
||||
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||
removeFieldNames.add(dslQueryFilter.getName());
|
||||
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||
EqualsTo equalsTo = new EqualsTo();
|
||||
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
|
||||
GreaterThan greaterThan = new GreaterThan();
|
||||
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
|
||||
MinorThan minorThan = new MinorThan();
|
||||
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
|
||||
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
|
||||
InExpression inExpression = new InExpression();
|
||||
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add in condition to sql where condition
|
||||
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
||||
InExpression inExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(dslQueryFilter.getName());
|
||||
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
|
||||
List<String> valueList = JsonUtil.toList(
|
||||
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
|
||||
if (CollectionUtils.isEmpty(valueList)) {
|
||||
return;
|
||||
}
|
||||
valueList.stream().forEach(o -> {
|
||||
StringValue stringValue = new StringValue(o);
|
||||
parenthesedExpressionList.add(stringValue);
|
||||
});
|
||||
inExpression.setLeftExpression(column);
|
||||
inExpression.setRightExpression(parenthesedExpressionList);
|
||||
addConditions.add(inExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// add where filter
|
||||
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
|
||||
T comparisonExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
String columnName = dslQueryFilter.getName();
|
||||
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
||||
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
||||
}
|
||||
if (Objects.isNull(dslQueryFilter.getValue())) {
|
||||
return;
|
||||
}
|
||||
Column column = new Column(columnName);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
|
||||
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
|
||||
comparisonExpression.setRightExpression(longValue);
|
||||
} else {
|
||||
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
}
|
||||
addConditions.add(comparisonExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData) {
|
||||
SemanticParseInfo parseInfo = queryData.getParseInfo();
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return parseInfo;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
|
||||
parseInfo.setDimensions(queryData.getDimensions());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
||||
parseInfo.setMetrics(queryData.getMetrics());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
|
||||
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) {
|
||||
parseInfo.setMetricFilters(queryData.getMetricFilters());
|
||||
}
|
||||
if (Objects.nonNull(queryData.getDateInfo())) {
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
private void validFilter(Set<QueryFilter> filters) {
|
||||
for (QueryFilter queryFilter : filters) {
|
||||
if (Objects.isNull(queryFilter.getValue())) {
|
||||
filters.remove(queryFilter);
|
||||
}
|
||||
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
|
||||
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
|
||||
filters.remove(queryFilter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
||||
SemanticQueryResp semanticQueryResp = new SemanticQueryResp();
|
||||
DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID());
|
||||
Set<Long> dataSetIds = dimensionValueReq.getDataSetIds();
|
||||
dimensionValueReq.setModelId(dimensionResp.getModelId());
|
||||
List<String> dimensionValues = getDimensionValues(dimensionValueReq, dataSetIds);
|
||||
// if the search results is null,search dimensionValue from database
|
||||
if (CollectionUtils.isEmpty(dimensionValues)) {
|
||||
semanticQueryResp = queryDatabase(dimensionValueReq, user);
|
||||
return semanticQueryResp;
|
||||
}
|
||||
List<QueryColumn> columns = new ArrayList<>();
|
||||
QueryColumn queryColumn = new QueryColumn();
|
||||
queryColumn.setNameEn(dimensionValueReq.getBizName());
|
||||
queryColumn.setShowType("CATEGORY");
|
||||
queryColumn.setAuthorized(true);
|
||||
queryColumn.setType("CHAR");
|
||||
columns.add(queryColumn);
|
||||
List<Map<String, Object>> resultList = new ArrayList<>();
|
||||
dimensionValues.stream().forEach(o -> {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
map.put(dimensionValueReq.getBizName(), o);
|
||||
resultList.add(map);
|
||||
});
|
||||
semanticQueryResp.setColumns(columns);
|
||||
semanticQueryResp.setResultList(resultList);
|
||||
return semanticQueryResp;
|
||||
}
|
||||
|
||||
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
|
||||
//if value is null ,then search from NATURE_TO_VALUES
|
||||
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
|
||||
return SearchService.getDimensionValue(dimensionValueReq);
|
||||
}
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
|
||||
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
|
||||
//search from prefixSearch
|
||||
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(),
|
||||
2000, modelIdToDataSetIds, dataSetIds);
|
||||
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
||||
return hanlpMapResultList.stream()
|
||||
.filter(o -> {
|
||||
for (String nature : o.getNatures()) {
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
if (dimensionValueReq.getElementID().equals(elementID)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.map(mapResult -> mapResult.getName())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private SemanticQueryResp queryDatabase(DimensionValueReq dimensionValueReq, User user) {
|
||||
QueryDimValueReq queryDimValueReq = new QueryDimValueReq();
|
||||
queryDimValueReq.setValue(dimensionValueReq.getValue());
|
||||
queryDimValueReq.setModelId(dimensionValueReq.getModelId());
|
||||
queryDimValueReq.setDimensionBizName(dimensionValueReq.getBizName());
|
||||
return semanticLayerService.queryDimValue(queryDimValueReq, user);
|
||||
}
|
||||
|
||||
public void correct(QuerySqlReq querySqlReq, User user) {
|
||||
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
|
||||
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||
@@ -19,7 +20,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
@@ -27,11 +29,17 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.headless.core.cache.QueryCache;
|
||||
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
|
||||
@@ -57,6 +65,7 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -77,6 +86,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
private final SchemaService schemaService;
|
||||
private final SemanticTranslator semanticTranslator;
|
||||
private final MetricDrillDownChecker metricDrillDownChecker;
|
||||
private final KnowledgeBaseService knowledgeBaseService;
|
||||
private QueryCache queryCache = ComponentFactory.getQueryCache();
|
||||
private List<QueryExecutor> queryExecutors = ComponentFactory.getQueryExecutors();
|
||||
|
||||
@@ -88,7 +98,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
DataSetService dataSetService,
|
||||
SchemaService schemaService,
|
||||
SemanticTranslator semanticTranslator,
|
||||
MetricDrillDownChecker metricDrillDownChecker) {
|
||||
MetricDrillDownChecker metricDrillDownChecker,
|
||||
KnowledgeBaseService knowledgeBaseService) {
|
||||
this.statUtils = statUtils;
|
||||
this.queryUtils = queryUtils;
|
||||
this.queryReqConverter = queryReqConverter;
|
||||
@@ -97,6 +108,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
this.schemaService = schemaService;
|
||||
this.semanticTranslator = semanticTranslator;
|
||||
this.metricDrillDownChecker = metricDrillDownChecker;
|
||||
this.knowledgeBaseService = knowledgeBaseService;
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long id) {
|
||||
@@ -175,12 +187,74 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
|
||||
QuerySqlReq querySqlReq = buildQuerySqlReq(queryDimValueReq);
|
||||
public SemanticQueryResp queryDimensionValue(DimensionValueReq dimensionValueReq, User user) {
|
||||
SemanticQueryResp semanticQueryResp = new SemanticQueryResp();
|
||||
DimensionResp dimensionResp = getDimension(dimensionValueReq);
|
||||
Set<Long> dataSetIds = dimensionValueReq.getDataSetIds();
|
||||
dimensionValueReq.setModelId(dimensionResp.getModelId());
|
||||
List<String> dimensionValues = getDimensionValuesFromDict(dimensionValueReq, dataSetIds);
|
||||
// if the search results is null,search dimensionValue from database
|
||||
if (CollectionUtils.isEmpty(dimensionValues)) {
|
||||
semanticQueryResp = getDimensionValuesFromDb(dimensionValueReq, user);
|
||||
return semanticQueryResp;
|
||||
}
|
||||
List<QueryColumn> columns = new ArrayList<>();
|
||||
QueryColumn queryColumn = new QueryColumn();
|
||||
queryColumn.setNameEn(dimensionValueReq.getBizName());
|
||||
queryColumn.setShowType(SemanticType.CATEGORY.name());
|
||||
queryColumn.setAuthorized(true);
|
||||
queryColumn.setType("CHAR");
|
||||
columns.add(queryColumn);
|
||||
List<Map<String, Object>> resultList = new ArrayList<>();
|
||||
dimensionValues.stream().forEach(o -> {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
map.put(dimensionValueReq.getBizName(), o);
|
||||
resultList.add(map);
|
||||
});
|
||||
semanticQueryResp.setColumns(columns);
|
||||
semanticQueryResp.setResultList(resultList);
|
||||
return semanticQueryResp;
|
||||
}
|
||||
|
||||
private List<String> getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
|
||||
//if value is null ,then search from NATURE_TO_VALUES
|
||||
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
|
||||
return SearchService.getDimensionValue(dimensionValueReq);
|
||||
}
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
|
||||
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
|
||||
//search from prefixSearch
|
||||
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(),
|
||||
2000, modelIdToDataSetIds, dataSetIds);
|
||||
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
||||
return hanlpMapResultList.stream()
|
||||
.filter(o -> {
|
||||
for (String nature : o.getNatures()) {
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
if (dimensionValueReq.getElementID().equals(elementID)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.map(MapResult::getName)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private SemanticQueryResp getDimensionValuesFromDb(DimensionValueReq dimensionValueReq, User user) {
|
||||
QuerySqlReq querySqlReq = buildQuerySqlReq(dimensionValueReq);
|
||||
return queryByReq(querySqlReq, user);
|
||||
}
|
||||
|
||||
private DimensionResp getDimension(DimensionValueReq dimensionValueReq) {
|
||||
DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID());
|
||||
if (dimensionResp == null) {
|
||||
return schemaService.getDimension(dimensionValueReq.getBizName(),
|
||||
dimensionValueReq.getModelId());
|
||||
}
|
||||
return dimensionResp;
|
||||
}
|
||||
|
||||
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
|
||||
if (parseInfo != null && parseInfo.getDataSetId() != null && parseInfo.getDataSetId() > 0) {
|
||||
EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
|
||||
@@ -291,10 +365,10 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
return schemaFilterReq;
|
||||
}
|
||||
|
||||
private QuerySqlReq buildQuerySqlReq(QueryDimValueReq queryDimValueReq) {
|
||||
private QuerySqlReq buildQuerySqlReq(DimensionValueReq queryDimValueReq) {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
List<ModelResp> modelResps = schemaService.getModelList(Lists.newArrayList(queryDimValueReq.getModelId()));
|
||||
DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getDimensionBizName(),
|
||||
DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getBizName(),
|
||||
queryDimValueReq.getModelId());
|
||||
ModelResp modelResp = modelResps.get(0);
|
||||
String sql = String.format("select distinct %s from %s where 1=1",
|
||||
@@ -306,7 +380,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
queryDimValueReq.getDateInfo().getEndDate());
|
||||
}
|
||||
if (StringUtils.isNotBlank(queryDimValueReq.getValue())) {
|
||||
sql += " AND " + queryDimValueReq.getDimensionBizName() + " LIKE '%" + queryDimValueReq.getValue() + "%'";
|
||||
sql += " AND " + queryDimValueReq.getBizName() + " LIKE '%" + queryDimValueReq.getValue() + "%'";
|
||||
}
|
||||
querySqlReq.setModelIds(Sets.newHashSet(queryDimValueReq.getModelId()));
|
||||
querySqlReq.setSql(sql);
|
||||
|
||||
@@ -7,15 +7,15 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.web.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
@@ -121,11 +121,11 @@ public class DimensionController {
|
||||
}
|
||||
|
||||
@PostMapping("/queryDimValue")
|
||||
public SemanticQueryResp queryDimValue(@RequestBody QueryDimValueReq queryDimValueReq,
|
||||
public SemanticQueryResp queryDimValue(@RequestBody DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return queryService.queryDimValue(queryDimValueReq, user);
|
||||
return queryService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
@DeleteMapping("deleteDimension/{id}")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -10,11 +10,11 @@ public class QueryDimensionTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testQueryDimValue() {
|
||||
QueryDimValueReq queryDimValueReq = new QueryDimValueReq();
|
||||
DimensionValueReq queryDimValueReq = new DimensionValueReq();
|
||||
queryDimValueReq.setModelId(1L);
|
||||
queryDimValueReq.setDimensionBizName("department");
|
||||
queryDimValueReq.setBizName("department");
|
||||
|
||||
SemanticQueryResp queryResp = semanticLayerService.queryDimValue(queryDimValueReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(4, queryResp.getResultList().size());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user