mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Switching metric supports default aggregation method of metric (#534)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -4,10 +4,8 @@ package com.tencent.supersonic.chat.rest;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryRecallResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
@@ -93,15 +91,4 @@ public class ChatController {
|
||||
return chatService.queryShowCase(pageQueryInfoCommand, agentId);
|
||||
}
|
||||
|
||||
@RequestMapping("/getSolvedQuery")
|
||||
public List<SimilarQueryRecallResp> getSolvedQuery(@RequestParam(value = "queryText") String queryText,
|
||||
@RequestParam(value = "agentId") Integer agentId) {
|
||||
QueryRecallResp queryRecallResp = new QueryRecallResp();
|
||||
Long startTime = System.currentTimeMillis();
|
||||
List<SimilarQueryRecallResp> solvedQueryRecallRespList = chatService.getSolvedQuery(queryText, agentId);
|
||||
queryRecallResp.setSolvedQueryRecallRespList(solvedQueryRecallRespList);
|
||||
queryRecallResp.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||
return solvedQueryRecallRespList;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
@@ -58,6 +57,4 @@ public interface ChatService {
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
Boolean deleteChatQuery(Long questionId);
|
||||
|
||||
List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId);
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
@@ -22,7 +21,6 @@ import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.compress.utils.Lists;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -239,32 +237,4 @@ public class ChatServiceImpl implements ChatService {
|
||||
return chatQueryRepository.deleteChatQuery(questionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId) {
|
||||
//1. recall solved query by queryText
|
||||
List<SimilarQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSimilarQuery(queryText, agentId);
|
||||
if (CollectionUtils.isEmpty(solvedQueryRecallResps)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<Long> queryIds = solvedQueryRecallResps.stream()
|
||||
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
|
||||
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
|
||||
pageQueryInfoReq.setIds(queryIds);
|
||||
pageQueryInfoReq.setPageSize(100);
|
||||
pageQueryInfoReq.setCurrent(1);
|
||||
//2. remove low score query
|
||||
int lowScoreThreshold = 3;
|
||||
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null);
|
||||
List<QueryResp> queryResps = queryRespPageInfo.getList();
|
||||
if (CollectionUtils.isEmpty(queryResps)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
|
||||
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
|
||||
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
|
||||
return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp ->
|
||||
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -28,11 +28,11 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
@@ -42,9 +42,9 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
@@ -77,6 +77,7 @@ import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
@@ -264,19 +265,19 @@ public class QueryServiceImpl implements QueryService {
|
||||
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
List<String> metrics = queryData.getMetrics().stream().map(o -> o.getName()).collect(Collectors.toList());
|
||||
List<String> fields = new ArrayList<>();
|
||||
if (Objects.nonNull(parseInfo.getSqlInfo())
|
||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
fields = SqlParserSelectHelper.getAllFields(correctorSql);
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(fields) && !fields.containsAll(metrics)
|
||||
&& CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
|
||||
&& checkMetricReplace(fields, queryData.getMetrics())) {
|
||||
//replace metrics
|
||||
log.info("llm begin replace metrics!");
|
||||
replaceMetrics(parseInfo, metrics);
|
||||
} else if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
|
||||
replaceMetrics(parseInfo, metricToReplace);
|
||||
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||
log.info("llm begin revise filters!");
|
||||
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
@@ -304,6 +305,17 @@ public class QueryServiceImpl implements QueryService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||
if (CollectionUtils.isEmpty(oriFields)) {
|
||||
return false;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return false;
|
||||
}
|
||||
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||
return !oriFields.containsAll(metricNames);
|
||||
}
|
||||
|
||||
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||
@@ -336,16 +348,16 @@ public class QueryServiceImpl implements QueryService {
|
||||
return correctorSql;
|
||||
}
|
||||
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, List<String> metrics) {
|
||||
List<String> filteredMetrics = parseInfo.getMetrics().stream()
|
||||
.map(o -> o.getName()).collect(Collectors.toList());
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
log.info("before replaceMetrics:{}", correctorSql);
|
||||
log.info("filteredMetrics:{},metrics:{}", filteredMetrics, metrics);
|
||||
Map<String, String> fieldMap = new HashMap<>();
|
||||
if (CollectionUtils.isNotEmpty(filteredMetrics) && CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldMap.put(filteredMetrics.get(0), metrics.get(0));
|
||||
correctorSql = SqlParserReplaceHelper.replaceSelectFields(correctorSql, fieldMap);
|
||||
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||
correctorSql = SqlParserReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
@@ -541,9 +553,9 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
|
||||
parseInfo.setDimensions(queryData.getDimensions());
|
||||
}
|
||||
//if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
||||
// parseInfo.setMetrics(queryData.getMetrics());
|
||||
//}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
|
||||
parseInfo.setMetrics(queryData.getMetrics());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
|
||||
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
||||
}
|
||||
|
||||
@@ -146,7 +146,8 @@ public class QueryReqBuilder {
|
||||
List<Aggregator> aggregators = new ArrayList<>();
|
||||
if (metric != null) {
|
||||
String agg = "";
|
||||
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)) {
|
||||
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)
|
||||
|| AggOperatorEnum.COUNT_DISTINCT.name().equalsIgnoreCase(metric.getDefaultAgg())) {
|
||||
if (StringUtils.isNotBlank(metric.getDefaultAgg())) {
|
||||
agg = metric.getDefaultAgg();
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ public class DataItem {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private String defaultAgg;
|
||||
|
||||
public String getNewName() {
|
||||
return newName == null ? name : newName;
|
||||
}
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
@@ -32,10 +27,16 @@ import net.sf.jsqlparser.statement.select.SelectBody;
|
||||
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||
import net.sf.jsqlparser.statement.select.SubSelect;
|
||||
import net.sf.jsqlparser.statement.select.SetOperationList;
|
||||
import net.sf.jsqlparser.statement.select.SubSelect;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Sql Parser replace Helper
|
||||
@@ -81,6 +82,39 @@ public class SqlParserReplaceHelper {
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String replaceAggFields(String sql, Map<String, Pair<String, String>> fieldNameToAggMap) {
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
((PlainSelect) selectBody).getSelectItems().stream().forEach(o -> {
|
||||
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) o;
|
||||
if (selectExpressionItem.getExpression() instanceof Function) {
|
||||
Function function = (Function) selectExpressionItem.getExpression();
|
||||
Column column = (Column) function.getParameters().getExpressions().get(0);
|
||||
if (fieldNameToAggMap.containsKey(column.getColumnName())) {
|
||||
Pair<String, String> agg = fieldNameToAggMap.get(column.getColumnName());
|
||||
String field = agg.getKey();
|
||||
String func = agg.getRight();
|
||||
if (AggOperatorEnum.isCountDistinct(func)) {
|
||||
function.setName("count");
|
||||
function.setDistinct(true);
|
||||
} else {
|
||||
function.setName(func);
|
||||
}
|
||||
List<Expression> expressions = new ArrayList<>();
|
||||
expressions.add(new Column(field));
|
||||
function.getParameters().setExpressions(expressions);
|
||||
if (Objects.nonNull(selectExpressionItem.getAlias()) && StringUtils.isNotBlank(field)) {
|
||||
selectExpressionItem.getAlias().setName(field);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||
return replaceValue(sql, filedNameToValueMap, true);
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* SqlParserReplaceHelperTest
|
||||
*/
|
||||
@@ -35,6 +37,18 @@ class SqlParserReplaceHelperTest {
|
||||
Assert.assertEquals("SELECT 维度1, 播放量1 FROM 数据库 WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void replaceAggField() {
|
||||
String sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1";
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
fieldMap.put("播放量", Pair.of("收听用户数", AggOperatorEnum.COUNT_DISTINCT.name()));
|
||||
sql = SqlParserReplaceHelper.replaceAggFields(sql, fieldMap);
|
||||
System.out.println(sql);
|
||||
Assert.assertEquals("SELECT 维度1, count(DISTINCT 收听用户数) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void replaceValue() {
|
||||
|
||||
|
||||
@@ -401,9 +401,12 @@ public class MetricServiceImpl implements MetricService {
|
||||
}
|
||||
|
||||
private DataItem getDataItem(MetricDO metricDO) {
|
||||
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
|
||||
new HashMap<>(), Lists.newArrayList());
|
||||
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
|
||||
.bizName(metricDO.getBizName())
|
||||
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC).build();
|
||||
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC)
|
||||
.defaultAgg(metricResp.getDefaultAgg()).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user