diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java index 0b68d8f2d..7a50e24a0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java @@ -4,10 +4,8 @@ package com.tencent.supersonic.chat.rest; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; -import com.tencent.supersonic.chat.api.pojo.response.QueryRecallResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; -import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.service.ChatService; import org.springframework.web.bind.annotation.GetMapping; @@ -93,15 +91,4 @@ public class ChatController { return chatService.queryShowCase(pageQueryInfoCommand, agentId); } - @RequestMapping("/getSolvedQuery") - public List getSolvedQuery(@RequestParam(value = "queryText") String queryText, - @RequestParam(value = "agentId") Integer agentId) { - QueryRecallResp queryRecallResp = new QueryRecallResp(); - Long startTime = System.currentTimeMillis(); - List solvedQueryRecallRespList = chatService.getSolvedQuery(queryText, agentId); - queryRecallResp.setSolvedQueryRecallRespList(solvedQueryRecallRespList); - queryRecallResp.setQueryTimeCost(System.currentTimeMillis() - startTime); - return solvedQueryRecallRespList; - } - } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java index 47a4a4c57..5414b117f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java @@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; -import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; @@ -58,6 +57,4 @@ public interface ChatService { ChatParseDO getParseInfo(Long questionId, int parseId); Boolean deleteChatQuery(Long questionId); - - List getSolvedQuery(String queryText, Integer agentId); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java index 6efcbd23d..d97f64f6e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java @@ -10,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; -import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; @@ -22,7 +21,6 @@ import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.utils.SimilarQueryManager; import com.tencent.supersonic.common.util.JsonUtil; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.compress.utils.Lists; import org.springframework.context.annotation.Primary; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -239,32 +237,4 @@ public class ChatServiceImpl implements ChatService { return chatQueryRepository.deleteChatQuery(questionId); } - @Override - public List getSolvedQuery(String queryText, Integer agentId) { - //1. recall solved query by queryText - List solvedQueryRecallResps = solvedQueryManager.recallSimilarQuery(queryText, agentId); - if (CollectionUtils.isEmpty(solvedQueryRecallResps)) { - return Lists.newArrayList(); - } - List queryIds = solvedQueryRecallResps.stream() - .map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList()); - PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq(); - pageQueryInfoReq.setIds(queryIds); - pageQueryInfoReq.setPageSize(100); - pageQueryInfoReq.setCurrent(1); - //2. remove low score query - int lowScoreThreshold = 3; - PageInfo queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null); - List queryResps = queryRespPageInfo.getList(); - if (CollectionUtils.isEmpty(queryResps)) { - return Lists.newArrayList(); - } - Set lowScoreQueryIds = queryResps.stream().filter(queryResp -> - queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold) - .map(QueryResp::getQuestionId).collect(Collectors.toSet()); - return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp -> - !lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId())) - .collect(Collectors.toList()); - } - } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 2b1bc277b..7a01fc136 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -28,11 +28,11 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; +import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor; import com.tencent.supersonic.chat.processor.parse.ParseResultProcessor; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; -import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.SemanticService; @@ -42,9 +42,9 @@ import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.SimilarQueryManager; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; @@ -77,6 +77,7 @@ import net.sf.jsqlparser.schema.Column; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Primary; @@ -264,19 +265,19 @@ public class QueryServiceImpl implements QueryService { SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); semanticQuery.setParseInfo(parseInfo); - List metrics = queryData.getMetrics().stream().map(o -> o.getName()).collect(Collectors.toList()); List fields = new ArrayList<>(); if (Objects.nonNull(parseInfo.getSqlInfo()) && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) { String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); fields = SqlParserSelectHelper.getAllFields(correctorSql); } - if (CollectionUtils.isNotEmpty(fields) && !fields.containsAll(metrics) - && CollectionUtils.isNotEmpty(queryData.getMetrics())) { + if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) + && checkMetricReplace(fields, queryData.getMetrics())) { //replace metrics log.info("llm begin replace metrics!"); - replaceMetrics(parseInfo, metrics); - } else if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { + SchemaElement metricToReplace = queryData.getMetrics().iterator().next(); + replaceMetrics(parseInfo, metricToReplace); + } else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { log.info("llm begin revise filters!"); String correctorSql = reviseCorrectS2SQL(queryData, parseInfo); parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); @@ -304,6 +305,17 @@ public class QueryServiceImpl implements QueryService { return queryResult; } + private boolean checkMetricReplace(List oriFields, Set metrics) { + if (CollectionUtils.isEmpty(oriFields)) { + return false; + } + if (CollectionUtils.isEmpty(metrics)) { + return false; + } + List metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()); + return !oriFields.containsAll(metricNames); + } + public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) { Map> filedNameToValueMap = new HashMap<>(); Map> havingFiledNameToValueMap = new HashMap<>(); @@ -336,16 +348,16 @@ public class QueryServiceImpl implements QueryService { return correctorSql; } - private void replaceMetrics(SemanticParseInfo parseInfo, List metrics) { - List filteredMetrics = parseInfo.getMetrics().stream() - .map(o -> o.getName()).collect(Collectors.toList()); + private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { + List oriMetrics = parseInfo.getMetrics().stream() + .map(SchemaElement::getName).collect(Collectors.toList()); String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); log.info("before replaceMetrics:{}", correctorSql); - log.info("filteredMetrics:{},metrics:{}", filteredMetrics, metrics); - Map fieldMap = new HashMap<>(); - if (CollectionUtils.isNotEmpty(filteredMetrics) && CollectionUtils.isNotEmpty(metrics)) { - fieldMap.put(filteredMetrics.get(0), metrics.get(0)); - correctorSql = SqlParserReplaceHelper.replaceSelectFields(correctorSql, fieldMap); + log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); + Map> fieldMap = new HashMap<>(); + if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) { + fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg())); + correctorSql = SqlParserReplaceHelper.replaceAggFields(correctorSql, fieldMap); } log.info("after replaceMetrics:{}", correctorSql); parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); @@ -541,9 +553,9 @@ public class QueryServiceImpl implements QueryService { if (CollectionUtils.isNotEmpty(queryData.getDimensions())) { parseInfo.setDimensions(queryData.getDimensions()); } - //if (CollectionUtils.isNotEmpty(queryData.getMetrics())) { - // parseInfo.setMetrics(queryData.getMetrics()); - //} + if (CollectionUtils.isNotEmpty(queryData.getMetrics())) { + parseInfo.setMetrics(queryData.getMetrics()); + } if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { parseInfo.setDimensionFilters(queryData.getDimensionFilters()); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java index 660b847ec..e808dfcc2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java @@ -146,7 +146,8 @@ public class QueryReqBuilder { List aggregators = new ArrayList<>(); if (metric != null) { String agg = ""; - if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)) { + if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE) + || AggOperatorEnum.COUNT_DISTINCT.name().equalsIgnoreCase(metric.getDefaultAgg())) { if (StringUtils.isNotBlank(metric.getDefaultAgg())) { agg = metric.getDefaultAgg(); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java b/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java index beb48e0a4..0cdf31708 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java @@ -20,6 +20,8 @@ public class DataItem { private Long modelId; + private String defaultAgg; + public String getNewName() { return newName == null ? name : newName; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java index ec5ac9a69..8f094c6be 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java @@ -1,11 +1,6 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.ArrayList; - +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.util.StringUtil; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; @@ -32,10 +27,16 @@ import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; -import net.sf.jsqlparser.statement.select.SubSelect; import net.sf.jsqlparser.statement.select.SetOperationList; +import net.sf.jsqlparser.statement.select.SubSelect; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; /** * Sql Parser replace Helper @@ -81,6 +82,39 @@ public class SqlParserReplaceHelper { return selectStatement.toString(); } + public static String replaceAggFields(String sql, Map> fieldNameToAggMap) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + ((PlainSelect) selectBody).getSelectItems().stream().forEach(o -> { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) o; + if (selectExpressionItem.getExpression() instanceof Function) { + Function function = (Function) selectExpressionItem.getExpression(); + Column column = (Column) function.getParameters().getExpressions().get(0); + if (fieldNameToAggMap.containsKey(column.getColumnName())) { + Pair agg = fieldNameToAggMap.get(column.getColumnName()); + String field = agg.getKey(); + String func = agg.getRight(); + if (AggOperatorEnum.isCountDistinct(func)) { + function.setName("count"); + function.setDistinct(true); + } else { + function.setName(func); + } + List expressions = new ArrayList<>(); + expressions.add(new Column(field)); + function.getParameters().setExpressions(expressions); + if (Objects.nonNull(selectExpressionItem.getAlias()) && StringUtils.isNotBlank(field)) { + selectExpressionItem.getAlias().setName(field); + } + } + } + }); + return selectStatement.toString(); + } + public static String replaceValue(String sql, Map> filedNameToValueMap) { return replaceValue(sql, filedNameToValueMap, true); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java index e5064a0c9..d3872992b 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -1,14 +1,16 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.Set; -import java.util.HashSet; -import java.util.Collections; -import java.util.Map; -import java.util.HashMap; - +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import org.apache.commons.lang3.tuple.Pair; import org.junit.Assert; import org.junit.jupiter.api.Test; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + /** * SqlParserReplaceHelperTest */ @@ -35,6 +37,18 @@ class SqlParserReplaceHelperTest { Assert.assertEquals("SELECT 维度1, 播放量1 FROM 数据库 WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql); } + @Test + void replaceAggField() { + String sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1"; + Map> fieldMap = new HashMap<>(); + fieldMap.put("播放量", Pair.of("收听用户数", AggOperatorEnum.COUNT_DISTINCT.name())); + sql = SqlParserReplaceHelper.replaceAggFields(sql, fieldMap); + System.out.println(sql); + Assert.assertEquals("SELECT 维度1, count(DISTINCT 收听用户数) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql); + } + @Test void replaceValue() { diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java index cca99cff6..d8de16a2b 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java @@ -401,9 +401,12 @@ public class MetricServiceImpl implements MetricService { } private DataItem getDataItem(MetricDO metricDO) { + MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO, + new HashMap<>(), Lists.newArrayList()); return DataItem.builder().id(metricDO.getId()).name(metricDO.getName()) .bizName(metricDO.getBizName()) - .modelId(metricDO.getModelId()).type(TypeEnums.METRIC).build(); + .modelId(metricDO.getModelId()).type(TypeEnums.METRIC) + .defaultAgg(metricResp.getDefaultAgg()).build(); } }