(improvement)(chat) Optimize the code for the queryData and queryDimensionValue interfaces. (#1529)

This commit is contained in:
lexluo09
2024-08-07 20:56:13 +08:00
committed by GitHub
parent 208686de46
commit 3d1ca6ac1d
5 changed files with 199 additions and 146 deletions

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
@@ -22,7 +23,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.BeanMapper;
@@ -33,13 +33,13 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -54,11 +54,8 @@ import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Column;
@@ -71,6 +68,7 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -195,32 +193,43 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return executeContext;
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
Integer parseId = chatQueryDataReq.getParseId();
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
chatQueryDataReq.getQueryId(), parseId);
parseInfo = mergeSemanticParseInfo(parseInfo, chatQueryDataReq);
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
fields = SqlSelectHelper.getAllSelectFields(correctorSql);
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
handleLLMQueryMode(chatQueryDataReq, semanticQuery, user);
} else {
handleRuleQueryMode(semanticQuery, dataSetSchema, user);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
//replace metrics
return executeQuery(semanticQuery, user, dataSetSchema);
}
private List<String> getFieldsFromSql(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isNotBlank(sqlInfo.getCorrectedS2SQL())) {
return new ArrayList<>();
}
return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL());
}
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq,
SemanticQuery semanticQuery,
User user) throws Exception {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
List<String> fields = getFieldsFromSql(parseInfo);
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
log.info("llm begin replace metrics!");
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
} else {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
@@ -228,16 +237,24 @@ public class ChatQueryServiceImpl implements ChatQueryService {
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql
semanticQuery.initS2Sql(dataSetSchema, user);
}
}
private void handleRuleQueryMode(SemanticQuery semanticQuery,
DataSetSchema dataSetSchema,
User user) {
log.info("rule begin replace metrics and revise filters!");
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
semanticQuery.initS2Sql(dataSetSchema, user);
}
private QueryResult executeQuery(SemanticQuery semanticQuery,
User user,
DataSetSchema dataSetSchema) throws Exception {
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
queryResult.setChatContext(semanticQuery.getParseInfo());
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
@@ -246,10 +263,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
if (CollectionUtils.isEmpty(oriFields)) {
return false;
}
if (CollectionUtils.isEmpty(metrics)) {
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
return false;
}
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
@@ -257,29 +271,30 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
List<Expression> addWhereConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions);
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Set<String> removeDataFieldNames = updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions);
removeWhereFieldNames.addAll(removeDataFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeHavingFieldNames = updateFilters(havingExpressionList,
queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>());
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
@@ -303,34 +318,32 @@ public class ChatQueryServiceImpl implements ChatQueryService {
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception {
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) throws Exception {
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult();
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
queryResult.setQuerySql(queryResp.getSql());
queryResult.setQueryResults(queryResp.getResultList());
queryResult.setQueryColumns(queryResp.getColumns());
} else {
queryResult.setQueryResults(new ArrayList<>());
queryResult.setQueryColumns(new ArrayList<>());
}
String sql = queryResp == null ? null : queryResp.getSql();
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
: queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(sql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryMode(queryMode);
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private void updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions,
Set<String> removeFieldNames) {
private Set<String> updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions) {
Set<String> removeFieldNames = new HashSet<>();
if (Objects.isNull(queryData.getDateInfo())) {
return;
return removeFieldNames;
}
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
@@ -369,6 +382,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
}
parseInfo.setDateInfo(queryData.getDateInfo());
return removeFieldNames;
}
private <T extends ComparisonOperator> void addTimeFilters(String date,
@@ -381,42 +395,41 @@ public class ChatQueryServiceImpl implements ChatQueryService {
addConditions.add(comparisonExpression);
}
private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (org.apache.commons.collections.CollectionUtils.isEmpty(metricFilters)) {
return;
private Set<String> updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
Set<String> removeFieldNames = new HashSet<>();
if (CollectionUtils.isEmpty(metricFilters)) {
return removeFieldNames;
}
for (QueryFilter dslQueryFilter : metricFilters) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (fieldExpression.getFieldName() != null
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
removeFieldNames.add(dslQueryFilter.getName());
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
EqualsTo equalsTo = new EqualsTo();
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
handleFilter(dslQueryFilter, contextMetricFilters, addConditions);
break;
}
}
}
return removeFieldNames;
}
private void handleFilter(QueryFilter dslQueryFilter,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
FilterOperatorEnum operator = dslQueryFilter.getOperator();
if (operator == FilterOperatorEnum.IN) {
addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, addConditions);
} else {
ComparisonOperator expression = FilterOperatorEnum.createExpression(operator);
if (Objects.nonNull(expression)) {
addWhereFilters(dslQueryFilter, expression, contextMetricFilters, addConditions);
}
}
}
// add in condition to sql where condition
@@ -428,7 +441,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
List<String> valueList = JsonUtil.toList(
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
if (org.apache.commons.collections.CollectionUtils.isEmpty(valueList)) {
if (CollectionUtils.isEmpty(valueList)) {
return;
}
valueList.stream().forEach(o -> {
@@ -447,10 +460,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
// add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
private void addWhereFilters(QueryFilter dslQueryFilter,
ComparisonOperator comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
@@ -476,8 +489,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
});
}
private SemanticParseInfo mergeSemanticParseInfo(SemanticParseInfo parseInfo,
ChatQueryDataReq queryData) {
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
ChatQueryDataReq queryData) {
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
}
@@ -500,13 +513,18 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) {
if (Objects.isNull(queryFilter.getValue())) {
filters.remove(queryFilter);
Iterator<QueryFilter> iterator = filters.iterator();
while (iterator.hasNext()) {
QueryFilter queryFilter = iterator.next();
Object queryFilterValue = queryFilter.getValue();
if (Objects.isNull(queryFilterValue)) {
iterator.remove();
continue;
}
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
filters.remove(queryFilter);
List<String> collection = JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
if (FilterOperatorEnum.IN.equals(queryFilter.getOperator())
&& CollectionUtils.isEmpty(collection)) {
iterator.remove();
}
}
}