(improvement)(chat) Switching metric supports default aggregation method of metric (#534)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-12-18 18:51:21 +08:00
committed by GitHub
parent 0c69651ef3
commit d7fafa361d
9 changed files with 99 additions and 79 deletions

View File

@@ -4,10 +4,8 @@ package com.tencent.supersonic.chat.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryRecallResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.service.ChatService;
import org.springframework.web.bind.annotation.GetMapping;
@@ -93,15 +91,4 @@ public class ChatController {
return chatService.queryShowCase(pageQueryInfoCommand, agentId);
}
@RequestMapping("/getSolvedQuery")
public List<SimilarQueryRecallResp> getSolvedQuery(@RequestParam(value = "queryText") String queryText,
@RequestParam(value = "agentId") Integer agentId) {
QueryRecallResp queryRecallResp = new QueryRecallResp();
Long startTime = System.currentTimeMillis();
List<SimilarQueryRecallResp> solvedQueryRecallRespList = chatService.getSolvedQuery(queryText, agentId);
queryRecallResp.setSolvedQueryRecallRespList(solvedQueryRecallRespList);
queryRecallResp.setQueryTimeCost(System.currentTimeMillis() - startTime);
return solvedQueryRecallRespList;
}
}

View File

@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -58,6 +57,4 @@ public interface ChatService {
ChatParseDO getParseInfo(Long questionId, int parseId);
Boolean deleteChatQuery(Long questionId);
List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId);
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -22,7 +21,6 @@ import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -239,32 +237,4 @@ public class ChatServiceImpl implements ChatService {
return chatQueryRepository.deleteChatQuery(questionId);
}
@Override
public List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId) {
//1. recall solved query by queryText
List<SimilarQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSimilarQuery(queryText, agentId);
if (CollectionUtils.isEmpty(solvedQueryRecallResps)) {
return Lists.newArrayList();
}
List<Long> queryIds = solvedQueryRecallResps.stream()
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
pageQueryInfoReq.setIds(queryIds);
pageQueryInfoReq.setPageSize(100);
pageQueryInfoReq.setCurrent(1);
//2. remove low score query
int lowScoreThreshold = 3;
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null);
List<QueryResp> queryResps = queryRespPageInfo.getList();
if (CollectionUtils.isEmpty(queryResps)) {
return Lists.newArrayList();
}
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp ->
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
.collect(Collectors.toList());
}
}

View File

@@ -28,11 +28,11 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SemanticService;
@@ -42,9 +42,9 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
@@ -77,6 +77,7 @@ import net.sf.jsqlparser.schema.Column;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary;
@@ -264,19 +265,19 @@ public class QueryServiceImpl implements QueryService {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
List<String> metrics = queryData.getMetrics().stream().map(o -> o.getName()).collect(Collectors.toList());
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
fields = SqlParserSelectHelper.getAllFields(correctorSql);
}
if (CollectionUtils.isNotEmpty(fields) && !fields.containsAll(metrics)
&& CollectionUtils.isNotEmpty(queryData.getMetrics())) {
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, queryData.getMetrics())) {
//replace metrics
log.info("llm begin replace metrics!");
replaceMetrics(parseInfo, metrics);
} else if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
@@ -304,6 +305,17 @@ public class QueryServiceImpl implements QueryService {
return queryResult;
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
if (CollectionUtils.isEmpty(oriFields)) {
return false;
}
if (CollectionUtils.isEmpty(metrics)) {
return false;
}
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
return !oriFields.containsAll(metricNames);
}
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
@@ -336,16 +348,16 @@ public class QueryServiceImpl implements QueryService {
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, List<String> metrics) {
List<String> filteredMetrics = parseInfo.getMetrics().stream()
.map(o -> o.getName()).collect(Collectors.toList());
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", filteredMetrics, metrics);
Map<String, String> fieldMap = new HashMap<>();
if (CollectionUtils.isNotEmpty(filteredMetrics) && CollectionUtils.isNotEmpty(metrics)) {
fieldMap.put(filteredMetrics.get(0), metrics.get(0));
correctorSql = SqlParserReplaceHelper.replaceSelectFields(correctorSql, fieldMap);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
correctorSql = SqlParserReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
@@ -541,9 +553,9 @@ public class QueryServiceImpl implements QueryService {
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
}
//if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
// parseInfo.setMetrics(queryData.getMetrics());
//}
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
parseInfo.setMetrics(queryData.getMetrics());
}
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}

View File

@@ -146,7 +146,8 @@ public class QueryReqBuilder {
List<Aggregator> aggregators = new ArrayList<>();
if (metric != null) {
String agg = "";
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)) {
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)
|| AggOperatorEnum.COUNT_DISTINCT.name().equalsIgnoreCase(metric.getDefaultAgg())) {
if (StringUtils.isNotBlank(metric.getDefaultAgg())) {
agg = metric.getDefaultAgg();
}

View File

@@ -20,6 +20,8 @@ public class DataItem {
private Long modelId;
private String defaultAgg;
public String getNewName() {
return newName == null ? name : newName;
}

View File

@@ -1,11 +1,6 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.ArrayList;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.StringUtil;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
@@ -32,10 +27,16 @@ import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.SubSelect;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
* Sql Parser replace Helper
@@ -81,6 +82,39 @@ public class SqlParserReplaceHelper {
return selectStatement.toString();
}
public static String replaceAggFields(String sql, Map<String, Pair<String, String>> fieldNameToAggMap) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
((PlainSelect) selectBody).getSelectItems().stream().forEach(o -> {
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) o;
if (selectExpressionItem.getExpression() instanceof Function) {
Function function = (Function) selectExpressionItem.getExpression();
Column column = (Column) function.getParameters().getExpressions().get(0);
if (fieldNameToAggMap.containsKey(column.getColumnName())) {
Pair<String, String> agg = fieldNameToAggMap.get(column.getColumnName());
String field = agg.getKey();
String func = agg.getRight();
if (AggOperatorEnum.isCountDistinct(func)) {
function.setName("count");
function.setDistinct(true);
} else {
function.setName(func);
}
List<Expression> expressions = new ArrayList<>();
expressions.add(new Column(field));
function.getParameters().setExpressions(expressions);
if (Objects.nonNull(selectExpressionItem.getAlias()) && StringUtils.isNotBlank(field)) {
selectExpressionItem.getAlias().setName(field);
}
}
}
});
return selectStatement.toString();
}
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
return replaceValue(sql, filedNameToValueMap, true);
}

View File

@@ -1,14 +1,16 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.Set;
import java.util.HashSet;
import java.util.Collections;
import java.util.Map;
import java.util.HashMap;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* SqlParserReplaceHelperTest
*/
@@ -35,6 +37,18 @@ class SqlParserReplaceHelperTest {
Assert.assertEquals("SELECT 维度1, 播放量1 FROM 数据库 WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
}
@Test
void replaceAggField() {
String sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1";
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
fieldMap.put("播放量", Pair.of("收听用户数", AggOperatorEnum.COUNT_DISTINCT.name()));
sql = SqlParserReplaceHelper.replaceAggFields(sql, fieldMap);
System.out.println(sql);
Assert.assertEquals("SELECT 维度1, count(DISTINCT 收听用户数) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql);
}
@Test
void replaceValue() {

View File

@@ -401,9 +401,12 @@ public class MetricServiceImpl implements MetricService {
}
private DataItem getDataItem(MetricDO metricDO) {
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
new HashMap<>(), Lists.newArrayList());
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
.bizName(metricDO.getBizName())
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC).build();
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC)
.defaultAgg(metricResp.getDefaultAgg()).build();
}
}