mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvment)(chat) if exist count() in dsl,set query to NATIVE and only order by field and group by field can add to select (#206)
This commit is contained in:
@@ -44,16 +44,16 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
|
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
|
||||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
|
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(sql));
|
||||||
|
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
needAddFields.removeAll(selectFields);
|
||||||
whereFields.removeAll(selectFields);
|
needAddFields.remove(DateUtils.DATE_FIELD);
|
||||||
whereFields.remove(DateUtils.DATE_FIELD);
|
String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields));
|
||||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
|
||||||
semanticCorrectInfo.setSql(replaceFields);
|
semanticCorrectInfo.setSql(replaceFields);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -22,7 +23,8 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
|
|||||||
if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) {
|
if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (QueryManager.isPluginQuery(semanticParseInfo.getQueryMode())) {
|
String queryMode = semanticParseInfo.getQueryMode();
|
||||||
|
if (QueryManager.isPluginQuery(queryMode) && !DslQuery.QUERY_MODE.equals(queryMode)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ public class EntityInfoParseResponder implements ParseResponder {
|
|||||||
QueryReq queryReq = queryContext.getRequest();
|
QueryReq queryReq = queryContext.getRequest();
|
||||||
selectedParses.forEach(parseInfo -> {
|
selectedParses.forEach(parseInfo -> {
|
||||||
if (QueryManager.isPluginQuery(parseInfo.getQueryMode())
|
if (QueryManager.isPluginQuery(parseInfo.getQueryMode())
|
||||||
&& !parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) {
|
&& !DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//1. set entity info
|
//1. set entity info
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
private SolvedQueryManager solvedQueryManager;
|
private SolvedQueryManager solvedQueryManager;
|
||||||
|
|
||||||
public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository,
|
public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository,
|
||||||
ChatQueryRepository chatQueryRepository, SolvedQueryManager solvedQueryManager) {
|
ChatQueryRepository chatQueryRepository, SolvedQueryManager solvedQueryManager) {
|
||||||
this.chatContextRepository = chatContextRepository;
|
this.chatContextRepository = chatContextRepository;
|
||||||
this.chatRepository = chatRepository;
|
this.chatRepository = chatRepository;
|
||||||
this.chatQueryRepository = chatQueryRepository;
|
this.chatQueryRepository = chatQueryRepository;
|
||||||
@@ -174,9 +174,9 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void batchAddParse(ChatContext chatCtx, QueryReq queryReq,
|
public void batchAddParse(ChatContext chatCtx, QueryReq queryReq,
|
||||||
ParseResp parseResult,
|
ParseResp parseResult,
|
||||||
List<SemanticParseInfo> candidateParses,
|
List<SemanticParseInfo> candidateParses,
|
||||||
List<SemanticParseInfo> selectedParses) {
|
List<SemanticParseInfo> selectedParses) {
|
||||||
chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -205,6 +205,8 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
List<SolvedQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSolvedQuery(queryText, agentId);
|
List<SolvedQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSolvedQuery(queryText, agentId);
|
||||||
List<Long> queryIds = solvedQueryRecallResps.stream()
|
List<Long> queryIds = solvedQueryRecallResps.stream()
|
||||||
.map(SolvedQueryRecallResp::getQueryId).collect(Collectors.toList());
|
.map(SolvedQueryRecallResp::getQueryId).collect(Collectors.toList());
|
||||||
|
List<Long> queryIds = solvedQueryRecallResps.stream().map(SolvedQueryRecallResp::getQueryId)
|
||||||
|
.collect(Collectors.toList());
|
||||||
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
|
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
|
||||||
pageQueryInfoReq.setIds(queryIds);
|
pageQueryInfoReq.setIds(queryIds);
|
||||||
pageQueryInfoReq.setPageSize(100);
|
pageQueryInfoReq.setPageSize(100);
|
||||||
@@ -219,7 +221,7 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
|
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
|
||||||
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
|
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
|
||||||
return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp ->
|
return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp ->
|
||||||
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
|
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
package com.tencent.supersonic.common.util.jsqlparser;
|
|
||||||
|
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
|
||||||
import net.sf.jsqlparser.expression.Function;
|
|
||||||
|
|
||||||
public class AggregateFunctionVisitor extends ExpressionVisitorAdapter {
|
|
||||||
|
|
||||||
private boolean hasAggregateFunction = false;
|
|
||||||
|
|
||||||
public boolean hasAggregateFunction() {
|
|
||||||
return hasAggregateFunction;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(Function function) {
|
|
||||||
super.visit(function);
|
|
||||||
hasAggregateFunction = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||||
|
import net.sf.jsqlparser.expression.Function;
|
||||||
|
|
||||||
|
public class FunctionVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
|
private Set<String> functionNames = new HashSet<>();
|
||||||
|
|
||||||
|
public Set<String> getFunctionNames() {
|
||||||
|
return functionNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visit(Function function) {
|
||||||
|
super.visit(function);
|
||||||
|
functionNames.add(function.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
package com.tencent.supersonic.common.util.jsqlparser;
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.expression.Function;
|
import net.sf.jsqlparser.expression.Function;
|
||||||
@@ -13,6 +15,7 @@ import net.sf.jsqlparser.statement.select.Select;
|
|||||||
import net.sf.jsqlparser.statement.select.SelectBody;
|
import net.sf.jsqlparser.statement.select.SelectBody;
|
||||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser Select function Helper
|
* Sql Parser Select function Helper
|
||||||
@@ -21,30 +24,34 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
public class SqlParserSelectFunctionHelper {
|
public class SqlParserSelectFunctionHelper {
|
||||||
|
|
||||||
public static boolean hasAggregateFunction(String sql) {
|
public static boolean hasAggregateFunction(String sql) {
|
||||||
if (hasFunction(sql)) {
|
if (!CollectionUtils.isEmpty(getFunctions(sql))) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return SqlParserSelectHelper.hasGroupBy(sql);
|
return SqlParserSelectHelper.hasGroupBy(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean hasFunction(String sql) {
|
public static boolean hasFunction(String sql, String functionName) {
|
||||||
|
Set<String> functions = getFunctions(sql);
|
||||||
|
if (!CollectionUtils.isEmpty(functions)) {
|
||||||
|
return functions.stream().anyMatch(function -> function.equalsIgnoreCase(functionName));
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Set<String> getFunctions(String sql) {
|
||||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||||
SelectBody selectBody = selectStatement.getSelectBody();
|
SelectBody selectBody = selectStatement.getSelectBody();
|
||||||
|
|
||||||
if (!(selectBody instanceof PlainSelect)) {
|
if (!(selectBody instanceof PlainSelect)) {
|
||||||
return false;
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||||
List<SelectItem> selectItems = plainSelect.getSelectItems();
|
List<SelectItem> selectItems = plainSelect.getSelectItems();
|
||||||
AggregateFunctionVisitor visitor = new AggregateFunctionVisitor();
|
FunctionVisitor visitor = new FunctionVisitor();
|
||||||
for (SelectItem selectItem : selectItems) {
|
for (SelectItem selectItem : selectItems) {
|
||||||
selectItem.accept(visitor);
|
selectItem.accept(visitor);
|
||||||
}
|
}
|
||||||
boolean selectFunction = visitor.hasAggregateFunction();
|
return visitor.getFunctionNames();
|
||||||
if (selectFunction) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Function getFunction(Expression expression, Map<String, String> fieldNameToAggregate) {
|
public static Function getFunction(Expression expression, Map<String, String> fieldNameToAggregate) {
|
||||||
|
|||||||
@@ -43,4 +43,34 @@ class SqlParserSelectFunctionHelperTest {
|
|||||||
Assert.assertEquals(hasAggregateFunction, true);
|
Assert.assertEquals(hasAggregateFunction, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void hasFunction() throws JSQLParserException {
|
||||||
|
|
||||||
|
String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
|
||||||
|
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
|
||||||
|
boolean hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "sum");
|
||||||
|
|
||||||
|
Assert.assertEquals(hasFunction, true);
|
||||||
|
sql = "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
|
||||||
|
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
|
||||||
|
hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "count");
|
||||||
|
Assert.assertEquals(hasFunction, true);
|
||||||
|
|
||||||
|
sql = "select 部门,count (*) from 超音数 where 数据日期 = '2023-08-08' "
|
||||||
|
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
|
||||||
|
hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "count");
|
||||||
|
Assert.assertEquals(hasFunction, true);
|
||||||
|
|
||||||
|
sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' "
|
||||||
|
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10";
|
||||||
|
hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "sum");
|
||||||
|
Assert.assertEquals(hasFunction, false);
|
||||||
|
|
||||||
|
sql = "select 部门,min (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
|
||||||
|
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
|
||||||
|
hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "min");
|
||||||
|
|
||||||
|
Assert.assertEquals(hasFunction, true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.semantic.query.parser.convert;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||||
@@ -81,10 +82,7 @@ public class QueryReqConverter {
|
|||||||
queryStructUtils.generateInternalMetricName(databaseReq.getModelId(),
|
queryStructUtils.generateInternalMetricName(databaseReq.getModelId(),
|
||||||
metricTable.getDimensions()))));
|
metricTable.getDimensions()))));
|
||||||
}
|
}
|
||||||
// if there is no group by in dsl,set MetricTable's aggOption to "NATIVE"
|
metricTable.setAggOption(getAggOption(databaseReq));
|
||||||
if (!SqlParserSelectHelper.hasGroupBy(databaseReq.getSql())) {
|
|
||||||
metricTable.setAggOption(AggOption.NATIVE);
|
|
||||||
}
|
|
||||||
List<MetricTable> tables = new ArrayList<>();
|
List<MetricTable> tables = new ArrayList<>();
|
||||||
tables.add(metricTable);
|
tables.add(metricTable);
|
||||||
//4.build ParseSqlReq
|
//4.build ParseSqlReq
|
||||||
@@ -104,6 +102,17 @@ public class QueryReqConverter {
|
|||||||
return queryStatement;
|
return queryStatement;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private AggOption getAggOption(QueryDslReq databaseReq) {
|
||||||
|
// if there is no group by in dsl,set MetricTable's aggOption to "NATIVE"
|
||||||
|
// if there is count() in dsl,set MetricTable's aggOption to "NATIVE"
|
||||||
|
String sql = databaseReq.getSql();
|
||||||
|
if (!SqlParserSelectHelper.hasGroupBy(sql)
|
||||||
|
|| SqlParserSelectFunctionHelper.hasFunction(sql, "count")) {
|
||||||
|
return AggOption.NATIVE;
|
||||||
|
}
|
||||||
|
return AggOption.DEFAULT;
|
||||||
|
}
|
||||||
|
|
||||||
private void convertNameToBizName(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) {
|
private void convertNameToBizName(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) {
|
||||||
Map<String, String> fieldNameToBizNameMap = getFieldNameToBizNameMap(modelSchemaResp);
|
Map<String, String> fieldNameToBizNameMap = getFieldNameToBizNameMap(modelSchemaResp);
|
||||||
String sql = databaseReq.getSql();
|
String sql = databaseReq.getSql();
|
||||||
|
|||||||
@@ -61,8 +61,7 @@ public class QueryController {
|
|||||||
|
|
||||||
@PostMapping("/queryStatement")
|
@PostMapping("/queryStatement")
|
||||||
public Object queryStatement(@RequestBody QueryStatement queryStatement) throws Exception {
|
public Object queryStatement(@RequestBody QueryStatement queryStatement) throws Exception {
|
||||||
Object result = queryService.queryByQueryStatement(queryStatement);
|
return queryService.queryByQueryStatement(queryStatement);
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/struct/parse")
|
@PostMapping("/struct/parse")
|
||||||
|
|||||||
@@ -89,8 +89,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public Object queryByQueryStatement(QueryStatement queryStatement) {
|
public Object queryByQueryStatement(QueryStatement queryStatement) {
|
||||||
QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement);
|
return semanticQueryEngine.execute(queryStatement);
|
||||||
return results;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception {
|
private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception {
|
||||||
|
|||||||
Reference in New Issue
Block a user