(improvement)(chat) Optimize the code for the queryData and queryDimensionValue interfaces. (#1529)

This commit is contained in:
lexluo09
2024-08-07 20:56:13 +08:00
committed by GitHub
parent 208686de46
commit 3d1ca6ac1d
5 changed files with 199 additions and 146 deletions

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq; import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor; import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser; import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
@@ -22,7 +23,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
@@ -33,13 +33,13 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -54,11 +54,8 @@ import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression; import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList; import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
@@ -71,6 +68,7 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -195,32 +193,43 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return executeContext; return executeContext;
} }
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override @Override
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception { public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
Integer parseId = chatQueryDataReq.getParseId(); Integer parseId = chatQueryDataReq.getParseId();
SemanticParseInfo parseInfo = chatManageService.getParseInfo( SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
chatQueryDataReq.getQueryId(), parseId); parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
parseInfo = mergeSemanticParseInfo(parseInfo, chatQueryDataReq);
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
List<String> fields = new ArrayList<>(); if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
if (Objects.nonNull(parseInfo.getSqlInfo()) handleLLMQueryMode(chatQueryDataReq, semanticQuery, user);
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { } else {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); handleRuleQueryMode(semanticQuery, dataSetSchema, user);
fields = SqlSelectHelper.getAllSelectFields(correctorSql);
} }
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { return executeQuery(semanticQuery, user, dataSetSchema);
//replace metrics }
private List<String> getFieldsFromSql(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isNotBlank(sqlInfo.getCorrectedS2SQL())) {
return new ArrayList<>();
}
return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL());
}
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq,
SemanticQuery semanticQuery,
User user) throws Exception {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
List<String> fields = getFieldsFromSql(parseInfo);
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
log.info("llm begin replace metrics!"); log.info("llm begin replace metrics!");
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next(); SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace); replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { } else {
log.info("llm begin revise filters!"); log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo); String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
@@ -228,16 +237,24 @@ public class ChatQueryServiceImpl implements ChatQueryService {
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql
semanticQuery.initS2Sql(dataSetSchema, user);
} }
}
private void handleRuleQueryMode(SemanticQuery semanticQuery,
DataSetSchema dataSetSchema,
User user) {
log.info("rule begin replace metrics and revise filters!");
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
semanticQuery.initS2Sql(dataSetSchema, user);
}
private QueryResult executeQuery(SemanticQuery semanticQuery,
User user,
DataSetSchema dataSetSchema) throws Exception {
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
queryResult.setChatContext(semanticQuery.getParseInfo()); queryResult.setChatContext(semanticQuery.getParseInfo());
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
@@ -246,10 +263,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) { private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
if (CollectionUtils.isEmpty(oriFields)) { if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
return false;
}
if (CollectionUtils.isEmpty(metrics)) {
return false; return false;
} }
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()); List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
@@ -257,29 +271,30 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) { private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql); log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter // get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql); List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter // replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(), List<Expression> addWhereConditions = new ArrayList<>();
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); Set<String> removeWhereFieldNames = updateFilters(whereExpressionList, queryData.getDimensionFilters(),
updateDateInfo(queryData, parseInfo, filedNameToValueMap, parseInfo.getDimensionFilters(), addWhereConditions);
whereExpressionList, addWhereConditions, removeWhereFieldNames);
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Set<String> removeDataFieldNames = updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions);
removeWhereFieldNames.addAll(removeDataFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter // replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(), List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); List<Expression> addHavingConditions = new ArrayList<>();
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); Set<String> removeHavingFieldNames = updateFilters(havingExpressionList,
queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>());
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions); correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
@@ -303,34 +318,32 @@ public class ChatQueryServiceImpl implements ChatQueryService {
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
} }
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) throws Exception {
SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user); SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
if (queryResp != null) { if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
queryResult.setQuerySql(queryResp.getSql());
queryResult.setQueryResults(queryResp.getResultList());
queryResult.setQueryColumns(queryResp.getColumns());
} else {
queryResult.setQueryResults(new ArrayList<>());
queryResult.setQueryColumns(new ArrayList<>());
} }
String sql = queryResp == null ? null : queryResp.getSql(); queryResult.setQueryMode(queryMode);
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
: queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(sql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS); queryResult.setQueryState(QueryState.SUCCESS);
return queryResult; return queryResult;
} }
private void updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, private Set<String> updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap, Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList, List<FieldExpression> fieldExpressionList,
List<Expression> addConditions, List<Expression> addConditions) {
Set<String> removeFieldNames) { Set<String> removeFieldNames = new HashSet<>();
if (Objects.isNull(queryData.getDateInfo())) { if (Objects.isNull(queryData.getDateInfo())) {
return; return removeFieldNames;
} }
if (queryData.getDateInfo().getUnit() > 1) { if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
@@ -369,6 +382,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
} }
parseInfo.setDateInfo(queryData.getDateInfo()); parseInfo.setDateInfo(queryData.getDateInfo());
return removeFieldNames;
} }
private <T extends ComparisonOperator> void addTimeFilters(String date, private <T extends ComparisonOperator> void addTimeFilters(String date,
@@ -381,42 +395,41 @@ public class ChatQueryServiceImpl implements ChatQueryService {
addConditions.add(comparisonExpression); addConditions.add(comparisonExpression);
} }
private void updateFilters(List<FieldExpression> fieldExpressionList, private Set<String> updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters, Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions, List<Expression> addConditions) {
Set<String> removeFieldNames) { Set<String> removeFieldNames = new HashSet<>();
if (org.apache.commons.collections.CollectionUtils.isEmpty(metricFilters)) { if (CollectionUtils.isEmpty(metricFilters)) {
return; return removeFieldNames;
} }
for (QueryFilter dslQueryFilter : metricFilters) { for (QueryFilter dslQueryFilter : metricFilters) {
for (FieldExpression fieldExpression : fieldExpressionList) { for (FieldExpression fieldExpression : fieldExpressionList) {
if (fieldExpression.getFieldName() != null if (fieldExpression.getFieldName() != null
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) { && fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
removeFieldNames.add(dslQueryFilter.getName()); removeFieldNames.add(dslQueryFilter.getName());
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { handleFilter(dslQueryFilter, contextMetricFilters, addConditions);
EqualsTo equalsTo = new EqualsTo();
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
break; break;
} }
} }
} }
return removeFieldNames;
}
private void handleFilter(QueryFilter dslQueryFilter,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
FilterOperatorEnum operator = dslQueryFilter.getOperator();
if (operator == FilterOperatorEnum.IN) {
addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, addConditions);
} else {
ComparisonOperator expression = FilterOperatorEnum.createExpression(operator);
if (Objects.nonNull(expression)) {
addWhereFilters(dslQueryFilter, expression, contextMetricFilters, addConditions);
}
}
} }
// add in condition to sql where condition // add in condition to sql where condition
@@ -428,7 +441,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
List<String> valueList = JsonUtil.toList( List<String> valueList = JsonUtil.toList(
JsonUtil.toString(dslQueryFilter.getValue()), String.class); JsonUtil.toString(dslQueryFilter.getValue()), String.class);
if (org.apache.commons.collections.CollectionUtils.isEmpty(valueList)) { if (CollectionUtils.isEmpty(valueList)) {
return; return;
} }
valueList.stream().forEach(o -> { valueList.stream().forEach(o -> {
@@ -447,10 +460,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
// add where filter // add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter, private void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression, ComparisonOperator comparisonExpression,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
String columnName = dslQueryFilter.getName(); String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
@@ -476,8 +489,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}); });
} }
private SemanticParseInfo mergeSemanticParseInfo(SemanticParseInfo parseInfo, private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
ChatQueryDataReq queryData) { ChatQueryDataReq queryData) {
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo; return parseInfo;
} }
@@ -500,13 +513,18 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void validFilter(Set<QueryFilter> filters) { private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) { Iterator<QueryFilter> iterator = filters.iterator();
if (Objects.isNull(queryFilter.getValue())) { while (iterator.hasNext()) {
filters.remove(queryFilter); QueryFilter queryFilter = iterator.next();
Object queryFilterValue = queryFilter.getValue();
if (Objects.isNull(queryFilterValue)) {
iterator.remove();
continue;
} }
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty( List<String> collection = JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) { if (FilterOperatorEnum.IN.equals(queryFilter.getOperator())
filters.remove(queryFilter); && CollectionUtils.isEmpty(collection)) {
iterator.remove();
} }
} }
} }

View File

@@ -25,13 +25,13 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
@Slf4j @Slf4j
public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private boolean exactReplace; private boolean exactReplace;
private Map<String, Map<String, String>> filedNameToValueMap; private Map<String, Map<String, String>> filedNameToValueMap;
public FieldlValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) { public FieldValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
this.exactReplace = exactReplace; this.exactReplace = exactReplace;
this.filedNameToValueMap = filedNameToValueMap; this.filedNameToValueMap = filedNameToValueMap;
} }

View File

@@ -2,15 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.StringUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.UnaryOperator;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.Alias;
@@ -30,6 +21,7 @@ import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.OrderByElement;
@@ -40,11 +32,18 @@ import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import net.sf.jsqlparser.statement.select.SetOperationList; import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.FromItem;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.UnaryOperator;
/** /**
* Sql Parser replace Helper * Sql Parser replace Helper
*/ */
@@ -132,7 +131,7 @@ public class SqlReplaceHelper {
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement); List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement);
for (PlainSelect plainSelect : plainSelects) { for (PlainSelect plainSelect : plainSelects) {
Expression where = plainSelect.getWhere(); Expression where = plainSelect.getWhere();
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap); FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(exactReplace, filedNameToValueMap);
if (Objects.nonNull(where)) { if (Objects.nonNull(where)) {
where.accept(visitor); where.accept(visitor);
} }
@@ -546,7 +545,7 @@ public class SqlReplaceHelper {
} }
PlainSelect plainSelect = (PlainSelect) selectStatement; PlainSelect plainSelect = (PlainSelect) selectStatement;
Expression having = plainSelect.getHaving(); Expression having = plainSelect.getHaving();
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap); FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(false, filedNameToValueMap);
if (Objects.nonNull(having)) { if (Objects.nonNull(having)) {
having.accept(visitor); having.accept(visitor);
} }

View File

@@ -3,6 +3,12 @@ package com.tencent.supersonic.common.pojo.enums;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
public enum FilterOperatorEnum { public enum FilterOperatorEnum {
IN("IN"), IN("IN"),
@@ -47,4 +53,20 @@ public enum FilterOperatorEnum {
|| MINOR_THAN_EQUALS.equals(filterOperatorEnum) || NOT_EQUALS.equals(filterOperatorEnum); || MINOR_THAN_EQUALS.equals(filterOperatorEnum) || NOT_EQUALS.equals(filterOperatorEnum);
} }
public static ComparisonOperator createExpression(FilterOperatorEnum operator) {
switch (operator) {
case EQUALS:
return new EqualsTo();
case GREATER_THAN_EQUALS:
return new GreaterThanEquals();
case GREATER_THAN:
return new GreaterThan();
case MINOR_THAN_EQUALS:
return new MinorThanEquals();
case MINOR_THAN:
return new MinorThan();
default:
return null;
}
}
} }

View File

@@ -202,50 +202,39 @@ public class S2SemanticLayerService implements SemanticLayerService {
DimensionResp dimensionResp = getDimension(dimensionValueReq); DimensionResp dimensionResp = getDimension(dimensionValueReq);
Set<Long> dataSetIds = dimensionValueReq.getDataSetIds(); Set<Long> dataSetIds = dimensionValueReq.getDataSetIds();
dimensionValueReq.setModelId(dimensionResp.getModelId()); dimensionValueReq.setModelId(dimensionResp.getModelId());
List<String> dimensionValues = getDimensionValuesFromDict(dimensionValueReq, dataSetIds); List<String> dimensionValues = getDimensionValuesFromDict(dimensionValueReq, dataSetIds);
// if the search results is null,search dimensionValue from database
// If the search results are null, search dimensionValue from the database
if (CollectionUtils.isEmpty(dimensionValues)) { if (CollectionUtils.isEmpty(dimensionValues)) {
return getDimensionValuesFromDb(dimensionValueReq, user); return getDimensionValuesFromDb(dimensionValueReq, user);
} }
List<QueryColumn> columns = new ArrayList<>();
QueryColumn queryColumn = new QueryColumn(); List<QueryColumn> columns = createQueryColumns(dimensionValueReq);
queryColumn.setNameEn(dimensionValueReq.getBizName()); List<Map<String, Object>> resultList = createResultList(dimensionValueReq, dimensionValues);
queryColumn.setShowType(SemanticType.CATEGORY.name());
queryColumn.setAuthorized(true);
queryColumn.setType("CHAR");
columns.add(queryColumn);
List<Map<String, Object>> resultList = new ArrayList<>();
dimensionValues.stream().forEach(o -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), o);
resultList.add(map);
});
semanticQueryResp.setColumns(columns); semanticQueryResp.setColumns(columns);
semanticQueryResp.setResultList(resultList); semanticQueryResp.setResultList(resultList);
return semanticQueryResp; return semanticQueryResp;
} }
private List<String> getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) { private List<String> getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
//if value is null ,then search from NATURE_TO_VALUES
if (StringUtils.isBlank(dimensionValueReq.getValue())) { if (StringUtils.isBlank(dimensionValueReq.getValue())) {
return SearchService.getDimensionValue(dimensionValueReq); return SearchService.getDimensionValue(dimensionValueReq);
} }
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>(); Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds)); modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(), List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(
2000, modelIdToDataSetIds, dataSetIds); dimensionValueReq.getValue(), 2000, modelIdToDataSetIds, dataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList); HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream() return hanlpMapResultList.stream()
.filter(o -> { .filter(o -> o.getNatures().stream()
for (String nature : o.getNatures()) { .map(NatureHelper::getElementID)
Long elementID = NatureHelper.getElementID(nature); .anyMatch(elementID -> dimensionValueReq.getElementID().equals(elementID)))
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
})
.map(MapResult::getName) .map(MapResult::getName)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -255,11 +244,36 @@ public class S2SemanticLayerService implements SemanticLayerService {
return queryByReq(querySqlReq, user); return queryByReq(querySqlReq, user);
} }
private List<QueryColumn> createQueryColumns(DimensionValueReq dimensionValueReq) {
QueryColumn queryColumn = new QueryColumn();
queryColumn.setNameEn(dimensionValueReq.getBizName());
queryColumn.setShowType(SemanticType.CATEGORY.name());
queryColumn.setAuthorized(true);
queryColumn.setType("CHAR");
List<QueryColumn> columns = new ArrayList<>();
columns.add(queryColumn);
return columns;
}
private List<Map<String, Object>> createResultList(DimensionValueReq dimensionValueReq,
List<String> dimensionValues) {
return dimensionValues.stream()
.map(value -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), value);
return map;
})
.collect(Collectors.toList());
}
private DimensionResp getDimension(DimensionValueReq dimensionValueReq) { private DimensionResp getDimension(DimensionValueReq dimensionValueReq) {
DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID()); Long elementID = dimensionValueReq.getElementID();
DimensionResp dimensionResp = schemaService.getDimension(elementID);
if (dimensionResp == null) { if (dimensionResp == null) {
return schemaService.getDimension(dimensionValueReq.getBizName(), String bizName = dimensionValueReq.getBizName();
dimensionValueReq.getModelId()); Long modelId = dimensionValueReq.getModelId();
return schemaService.getDimension(bizName, modelId);
} }
return dimensionResp; return dimensionResp;
} }