mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Switching metric supports default aggregation method of metric (#534)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -4,10 +4,8 @@ package com.tencent.supersonic.chat.rest;
|
|||||||
import com.github.pagehelper.PageInfo;
|
import com.github.pagehelper.PageInfo;
|
||||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryRecallResp;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||||
import com.tencent.supersonic.chat.service.ChatService;
|
import com.tencent.supersonic.chat.service.ChatService;
|
||||||
import org.springframework.web.bind.annotation.GetMapping;
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
@@ -93,15 +91,4 @@ public class ChatController {
|
|||||||
return chatService.queryShowCase(pageQueryInfoCommand, agentId);
|
return chatService.queryShowCase(pageQueryInfoCommand, agentId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@RequestMapping("/getSolvedQuery")
|
|
||||||
public List<SimilarQueryRecallResp> getSolvedQuery(@RequestParam(value = "queryText") String queryText,
|
|
||||||
@RequestParam(value = "agentId") Integer agentId) {
|
|
||||||
QueryRecallResp queryRecallResp = new QueryRecallResp();
|
|
||||||
Long startTime = System.currentTimeMillis();
|
|
||||||
List<SimilarQueryRecallResp> solvedQueryRecallRespList = chatService.getSolvedQuery(queryText, agentId);
|
|
||||||
queryRecallResp.setSolvedQueryRecallRespList(solvedQueryRecallRespList);
|
|
||||||
queryRecallResp.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
|
||||||
return solvedQueryRecallRespList;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||||
@@ -58,6 +57,4 @@ public interface ChatService {
|
|||||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||||
|
|
||||||
Boolean deleteChatQuery(Long questionId);
|
Boolean deleteChatQuery(Long questionId);
|
||||||
|
|
||||||
List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||||
@@ -22,7 +21,6 @@ import com.tencent.supersonic.chat.service.ChatService;
|
|||||||
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
|
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.compress.utils.Lists;
|
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -239,32 +237,4 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
return chatQueryRepository.deleteChatQuery(questionId);
|
return chatQueryRepository.deleteChatQuery(questionId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId) {
|
|
||||||
//1. recall solved query by queryText
|
|
||||||
List<SimilarQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSimilarQuery(queryText, agentId);
|
|
||||||
if (CollectionUtils.isEmpty(solvedQueryRecallResps)) {
|
|
||||||
return Lists.newArrayList();
|
|
||||||
}
|
|
||||||
List<Long> queryIds = solvedQueryRecallResps.stream()
|
|
||||||
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
|
|
||||||
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
|
|
||||||
pageQueryInfoReq.setIds(queryIds);
|
|
||||||
pageQueryInfoReq.setPageSize(100);
|
|
||||||
pageQueryInfoReq.setCurrent(1);
|
|
||||||
//2. remove low score query
|
|
||||||
int lowScoreThreshold = 3;
|
|
||||||
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null);
|
|
||||||
List<QueryResp> queryResps = queryRespPageInfo.getList();
|
|
||||||
if (CollectionUtils.isEmpty(queryResps)) {
|
|
||||||
return Lists.newArrayList();
|
|
||||||
}
|
|
||||||
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
|
|
||||||
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
|
|
||||||
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
|
|
||||||
return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp ->
|
|
||||||
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
|||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||||
|
import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor;
|
||||||
import com.tencent.supersonic.chat.processor.parse.ParseResultProcessor;
|
import com.tencent.supersonic.chat.processor.parse.ParseResultProcessor;
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor;
|
|
||||||
import com.tencent.supersonic.chat.service.ChatService;
|
import com.tencent.supersonic.chat.service.ChatService;
|
||||||
import com.tencent.supersonic.chat.service.QueryService;
|
import com.tencent.supersonic.chat.service.QueryService;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
@@ -42,9 +42,9 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
|
|||||||
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
|
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
@@ -77,6 +77,7 @@ import net.sf.jsqlparser.schema.Column;
|
|||||||
import org.apache.calcite.sql.parser.SqlParseException;
|
import org.apache.calcite.sql.parser.SqlParseException;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
@@ -264,19 +265,19 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
|
|
||||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||||
semanticQuery.setParseInfo(parseInfo);
|
semanticQuery.setParseInfo(parseInfo);
|
||||||
List<String> metrics = queryData.getMetrics().stream().map(o -> o.getName()).collect(Collectors.toList());
|
|
||||||
List<String> fields = new ArrayList<>();
|
List<String> fields = new ArrayList<>();
|
||||||
if (Objects.nonNull(parseInfo.getSqlInfo())
|
if (Objects.nonNull(parseInfo.getSqlInfo())
|
||||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
fields = SqlParserSelectHelper.getAllFields(correctorSql);
|
fields = SqlParserSelectHelper.getAllFields(correctorSql);
|
||||||
}
|
}
|
||||||
if (CollectionUtils.isNotEmpty(fields) && !fields.containsAll(metrics)
|
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
|
||||||
&& CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
&& checkMetricReplace(fields, queryData.getMetrics())) {
|
||||||
//replace metrics
|
//replace metrics
|
||||||
log.info("llm begin replace metrics!");
|
log.info("llm begin replace metrics!");
|
||||||
replaceMetrics(parseInfo, metrics);
|
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
|
||||||
} else if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
replaceMetrics(parseInfo, metricToReplace);
|
||||||
|
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||||
log.info("llm begin revise filters!");
|
log.info("llm begin revise filters!");
|
||||||
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
|
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
|
||||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||||
@@ -304,6 +305,17 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
return queryResult;
|
return queryResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||||
|
if (CollectionUtils.isEmpty(oriFields)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||||
|
return !oriFields.containsAll(metricNames);
|
||||||
|
}
|
||||||
|
|
||||||
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
|
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
|
||||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||||
@@ -336,16 +348,16 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
return correctorSql;
|
return correctorSql;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void replaceMetrics(SemanticParseInfo parseInfo, List<String> metrics) {
|
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||||
List<String> filteredMetrics = parseInfo.getMetrics().stream()
|
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||||
.map(o -> o.getName()).collect(Collectors.toList());
|
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
log.info("before replaceMetrics:{}", correctorSql);
|
log.info("before replaceMetrics:{}", correctorSql);
|
||||||
log.info("filteredMetrics:{},metrics:{}", filteredMetrics, metrics);
|
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||||
Map<String, String> fieldMap = new HashMap<>();
|
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||||
if (CollectionUtils.isNotEmpty(filteredMetrics) && CollectionUtils.isNotEmpty(metrics)) {
|
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||||
fieldMap.put(filteredMetrics.get(0), metrics.get(0));
|
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||||
correctorSql = SqlParserReplaceHelper.replaceSelectFields(correctorSql, fieldMap);
|
correctorSql = SqlParserReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||||
}
|
}
|
||||||
log.info("after replaceMetrics:{}", correctorSql);
|
log.info("after replaceMetrics:{}", correctorSql);
|
||||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||||
@@ -541,9 +553,9 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
|
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
|
||||||
parseInfo.setDimensions(queryData.getDimensions());
|
parseInfo.setDimensions(queryData.getDimensions());
|
||||||
}
|
}
|
||||||
//if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
||||||
// parseInfo.setMetrics(queryData.getMetrics());
|
parseInfo.setMetrics(queryData.getMetrics());
|
||||||
//}
|
}
|
||||||
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
|
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
|
||||||
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,8 @@ public class QueryReqBuilder {
|
|||||||
List<Aggregator> aggregators = new ArrayList<>();
|
List<Aggregator> aggregators = new ArrayList<>();
|
||||||
if (metric != null) {
|
if (metric != null) {
|
||||||
String agg = "";
|
String agg = "";
|
||||||
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)) {
|
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)
|
||||||
|
|| AggOperatorEnum.COUNT_DISTINCT.name().equalsIgnoreCase(metric.getDefaultAgg())) {
|
||||||
if (StringUtils.isNotBlank(metric.getDefaultAgg())) {
|
if (StringUtils.isNotBlank(metric.getDefaultAgg())) {
|
||||||
agg = metric.getDefaultAgg();
|
agg = metric.getDefaultAgg();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ public class DataItem {
|
|||||||
|
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
|
|
||||||
|
private String defaultAgg;
|
||||||
|
|
||||||
public String getNewName() {
|
public String getNewName() {
|
||||||
return newName == null ? name : newName;
|
return newName == null ? name : newName;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
package com.tencent.supersonic.common.util.jsqlparser;
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
import java.util.List;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
@@ -32,10 +27,16 @@ import net.sf.jsqlparser.statement.select.SelectBody;
|
|||||||
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
|
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
|
||||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||||
import net.sf.jsqlparser.statement.select.SubSelect;
|
|
||||||
import net.sf.jsqlparser.statement.select.SetOperationList;
|
import net.sf.jsqlparser.statement.select.SetOperationList;
|
||||||
|
import net.sf.jsqlparser.statement.select.SubSelect;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser replace Helper
|
* Sql Parser replace Helper
|
||||||
@@ -81,6 +82,39 @@ public class SqlParserReplaceHelper {
|
|||||||
return selectStatement.toString();
|
return selectStatement.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String replaceAggFields(String sql, Map<String, Pair<String, String>> fieldNameToAggMap) {
|
||||||
|
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||||
|
SelectBody selectBody = selectStatement.getSelectBody();
|
||||||
|
if (!(selectBody instanceof PlainSelect)) {
|
||||||
|
return sql;
|
||||||
|
}
|
||||||
|
((PlainSelect) selectBody).getSelectItems().stream().forEach(o -> {
|
||||||
|
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) o;
|
||||||
|
if (selectExpressionItem.getExpression() instanceof Function) {
|
||||||
|
Function function = (Function) selectExpressionItem.getExpression();
|
||||||
|
Column column = (Column) function.getParameters().getExpressions().get(0);
|
||||||
|
if (fieldNameToAggMap.containsKey(column.getColumnName())) {
|
||||||
|
Pair<String, String> agg = fieldNameToAggMap.get(column.getColumnName());
|
||||||
|
String field = agg.getKey();
|
||||||
|
String func = agg.getRight();
|
||||||
|
if (AggOperatorEnum.isCountDistinct(func)) {
|
||||||
|
function.setName("count");
|
||||||
|
function.setDistinct(true);
|
||||||
|
} else {
|
||||||
|
function.setName(func);
|
||||||
|
}
|
||||||
|
List<Expression> expressions = new ArrayList<>();
|
||||||
|
expressions.add(new Column(field));
|
||||||
|
function.getParameters().setExpressions(expressions);
|
||||||
|
if (Objects.nonNull(selectExpressionItem.getAlias()) && StringUtils.isNotBlank(field)) {
|
||||||
|
selectExpressionItem.getAlias().setName(field);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return selectStatement.toString();
|
||||||
|
}
|
||||||
|
|
||||||
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
|
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||||
return replaceValue(sql, filedNameToValueMap, true);
|
return replaceValue(sql, filedNameToValueMap, true);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
package com.tencent.supersonic.common.util.jsqlparser;
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
import java.util.Set;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import java.util.HashSet;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.HashMap;
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SqlParserReplaceHelperTest
|
* SqlParserReplaceHelperTest
|
||||||
*/
|
*/
|
||||||
@@ -35,6 +37,18 @@ class SqlParserReplaceHelperTest {
|
|||||||
Assert.assertEquals("SELECT 维度1, 播放量1 FROM 数据库 WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
|
Assert.assertEquals("SELECT 维度1, 播放量1 FROM 数据库 WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void replaceAggField() {
|
||||||
|
String sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
|
||||||
|
+ "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1";
|
||||||
|
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||||
|
fieldMap.put("播放量", Pair.of("收听用户数", AggOperatorEnum.COUNT_DISTINCT.name()));
|
||||||
|
sql = SqlParserReplaceHelper.replaceAggFields(sql, fieldMap);
|
||||||
|
System.out.println(sql);
|
||||||
|
Assert.assertEquals("SELECT 维度1, count(DISTINCT 收听用户数) FROM 数据库 "
|
||||||
|
+ "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void replaceValue() {
|
void replaceValue() {
|
||||||
|
|
||||||
|
|||||||
@@ -401,9 +401,12 @@ public class MetricServiceImpl implements MetricService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private DataItem getDataItem(MetricDO metricDO) {
|
private DataItem getDataItem(MetricDO metricDO) {
|
||||||
|
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
|
||||||
|
new HashMap<>(), Lists.newArrayList());
|
||||||
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
|
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
|
||||||
.bizName(metricDO.getBizName())
|
.bizName(metricDO.getBizName())
|
||||||
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC).build();
|
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC)
|
||||||
|
.defaultAgg(metricResp.getDefaultAgg()).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user