From f605cf0ef99631712375460af7b2c8d71ef95163 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:25:24 +0800 Subject: [PATCH] (improvment)(chat) if exist count() in dsl,set query to NATIVE and only order by field and group by field can add to select (#206) --- .../chat/corrector/BaseSemanticCorrector.java | 12 ++++---- .../execute/EntityInfoExecuteResponder.java | 4 ++- .../parse/EntityInfoParseResponder.java | 2 +- .../chat/service/impl/ChatServiceImpl.java | 12 ++++---- .../jsqlparser/AggregateFunctionVisitor.java | 19 ------------ .../util/jsqlparser/FunctionVisitor.java | 21 +++++++++++++ .../SqlParserSelectFunctionHelper.java | 25 ++++++++++------ .../SqlParserSelectFunctionHelperTest.java | 30 +++++++++++++++++++ .../parser/convert/QueryReqConverter.java | 17 ++++++++--- .../semantic/query/rest/QueryController.java | 3 +- .../query/service/QueryServiceImpl.java | 3 +- 11 files changed, 99 insertions(+), 49 deletions(-) delete mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateFunctionVisitor.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionVisitor.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 455643938..ddd49f2c9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -44,16 +44,16 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) { Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql)); - Set whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql)); + Set needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(sql)); + needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); - if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { + if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { return; } - whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); - whereFields.removeAll(selectFields); - whereFields.remove(DateUtils.DATE_FIELD); - String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); + needAddFields.removeAll(selectFields); + needAddFields.remove(DateUtils.DATE_FIELD); + String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields)); semanticCorrectInfo.setSql(replaceFields); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java index c89137af5..d658aee0f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; import java.util.List; @@ -22,7 +23,8 @@ public class EntityInfoExecuteResponder implements ExecuteResponder { if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) { return; } - if (QueryManager.isPluginQuery(semanticParseInfo.getQueryMode())) { + String queryMode = semanticParseInfo.getQueryMode(); + if (QueryManager.isPluginQuery(queryMode) && !DslQuery.QUERY_MODE.equals(queryMode)) { return; } SemanticService semanticService = ContextUtils.getBean(SemanticService.class); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index dc9f570b8..15a8e041a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -24,7 +24,7 @@ public class EntityInfoParseResponder implements ParseResponder { QueryReq queryReq = queryContext.getRequest(); selectedParses.forEach(parseInfo -> { if (QueryManager.isPluginQuery(parseInfo.getQueryMode()) - && !parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) { + && !DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { return; } //1. set entity info diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java index e374b635d..4d985a654 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java @@ -46,7 +46,7 @@ public class ChatServiceImpl implements ChatService { private SolvedQueryManager solvedQueryManager; public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository, - ChatQueryRepository chatQueryRepository, SolvedQueryManager solvedQueryManager) { + ChatQueryRepository chatQueryRepository, SolvedQueryManager solvedQueryManager) { this.chatContextRepository = chatContextRepository; this.chatRepository = chatRepository; this.chatQueryRepository = chatQueryRepository; @@ -174,9 +174,9 @@ public class ChatServiceImpl implements ChatService { @Override public void batchAddParse(ChatContext chatCtx, QueryReq queryReq, - ParseResp parseResult, - List candidateParses, - List selectedParses) { + ParseResp parseResult, + List candidateParses, + List selectedParses) { chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses); } @@ -205,6 +205,8 @@ public class ChatServiceImpl implements ChatService { List solvedQueryRecallResps = solvedQueryManager.recallSolvedQuery(queryText, agentId); List queryIds = solvedQueryRecallResps.stream() .map(SolvedQueryRecallResp::getQueryId).collect(Collectors.toList()); + List queryIds = solvedQueryRecallResps.stream().map(SolvedQueryRecallResp::getQueryId) + .collect(Collectors.toList()); PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq(); pageQueryInfoReq.setIds(queryIds); pageQueryInfoReq.setPageSize(100); @@ -219,7 +221,7 @@ public class ChatServiceImpl implements ChatService { queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold) .map(QueryResp::getQuestionId).collect(Collectors.toSet()); return solvedQueryRecallResps.stream().filter(solvedQueryRecallResp -> - !lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId())) + !lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId())) .collect(Collectors.toList()); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateFunctionVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateFunctionVisitor.java deleted file mode 100644 index f9b9d5ac1..000000000 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateFunctionVisitor.java +++ /dev/null @@ -1,19 +0,0 @@ -package com.tencent.supersonic.common.util.jsqlparser; - -import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; -import net.sf.jsqlparser.expression.Function; - -public class AggregateFunctionVisitor extends ExpressionVisitorAdapter { - - private boolean hasAggregateFunction = false; - - public boolean hasAggregateFunction() { - return hasAggregateFunction; - } - - @Override - public void visit(Function function) { - super.visit(function); - hasAggregateFunction = true; - } -} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionVisitor.java new file mode 100644 index 000000000..16ee86882 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionVisitor.java @@ -0,0 +1,21 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.HashSet; +import java.util.Set; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; + +public class FunctionVisitor extends ExpressionVisitorAdapter { + + private Set functionNames = new HashSet<>(); + + public Set getFunctionNames() { + return functionNames; + } + + @Override + public void visit(Function function) { + super.visit(function); + functionNames.add(function.getName()); + } +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java index c07e2ee2e..4f54a7470 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java @@ -1,8 +1,10 @@ package com.tencent.supersonic.common.util.jsqlparser; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; @@ -13,6 +15,7 @@ import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectItem; import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; /** * Sql Parser Select function Helper @@ -21,30 +24,34 @@ import org.apache.commons.lang3.StringUtils; public class SqlParserSelectFunctionHelper { public static boolean hasAggregateFunction(String sql) { - if (hasFunction(sql)) { + if (!CollectionUtils.isEmpty(getFunctions(sql))) { return true; } return SqlParserSelectHelper.hasGroupBy(sql); } - public static boolean hasFunction(String sql) { + public static boolean hasFunction(String sql, String functionName) { + Set functions = getFunctions(sql); + if (!CollectionUtils.isEmpty(functions)) { + return functions.stream().anyMatch(function -> function.equalsIgnoreCase(functionName)); + } + return false; + } + + public static Set getFunctions(String sql) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectBody instanceof PlainSelect)) { - return false; + return new HashSet<>(); } PlainSelect plainSelect = (PlainSelect) selectBody; List selectItems = plainSelect.getSelectItems(); - AggregateFunctionVisitor visitor = new AggregateFunctionVisitor(); + FunctionVisitor visitor = new FunctionVisitor(); for (SelectItem selectItem : selectItems) { selectItem.accept(visitor); } - boolean selectFunction = visitor.hasAggregateFunction(); - if (selectFunction) { - return true; - } - return false; + return visitor.getFunctionNames(); } public static Function getFunction(Expression expression, Map fieldNameToAggregate) { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java index 0887e7921..c25b50806 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java @@ -43,4 +43,34 @@ class SqlParserSelectFunctionHelperTest { Assert.assertEquals(hasAggregateFunction, true); } + + @Test + void hasFunction() throws JSQLParserException { + + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + boolean hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "sum"); + + Assert.assertEquals(hasFunction, true); + sql = "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "count"); + Assert.assertEquals(hasFunction, true); + + sql = "select 部门,count (*) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "count"); + Assert.assertEquals(hasFunction, true); + + sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; + hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "sum"); + Assert.assertEquals(hasFunction, false); + + sql = "select 部门,min (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + hasFunction = SqlParserSelectFunctionHelper.hasFunction(sql, "min"); + + Assert.assertEquals(hasFunction, true); + } } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java index aa4e297a7..ca7fccd92 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.semantic.query.parser.convert; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; @@ -81,10 +82,7 @@ public class QueryReqConverter { queryStructUtils.generateInternalMetricName(databaseReq.getModelId(), metricTable.getDimensions())))); } - // if there is no group by in dsl,set MetricTable's aggOption to "NATIVE" - if (!SqlParserSelectHelper.hasGroupBy(databaseReq.getSql())) { - metricTable.setAggOption(AggOption.NATIVE); - } + metricTable.setAggOption(getAggOption(databaseReq)); List tables = new ArrayList<>(); tables.add(metricTable); //4.build ParseSqlReq @@ -104,6 +102,17 @@ public class QueryReqConverter { return queryStatement; } + private AggOption getAggOption(QueryDslReq databaseReq) { + // if there is no group by in dsl,set MetricTable's aggOption to "NATIVE" + // if there is count() in dsl,set MetricTable's aggOption to "NATIVE" + String sql = databaseReq.getSql(); + if (!SqlParserSelectHelper.hasGroupBy(sql) + || SqlParserSelectFunctionHelper.hasFunction(sql, "count")) { + return AggOption.NATIVE; + } + return AggOption.DEFAULT; + } + private void convertNameToBizName(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) { Map fieldNameToBizNameMap = getFieldNameToBizNameMap(modelSchemaResp); String sql = databaseReq.getSql(); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java index b7f215c56..b0efe12b3 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java @@ -61,8 +61,7 @@ public class QueryController { @PostMapping("/queryStatement") public Object queryStatement(@RequestBody QueryStatement queryStatement) throws Exception { - Object result = queryService.queryByQueryStatement(queryStatement); - return result; + return queryService.queryByQueryStatement(queryStatement); } @PostMapping("/struct/parse") diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java index 3e98778b9..69c13d55a 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java @@ -89,8 +89,7 @@ public class QueryServiceImpl implements QueryService { } public Object queryByQueryStatement(QueryStatement queryStatement) { - QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement); - return results; + return semanticQueryEngine.execute(queryStatement); } private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception {