From 46a9e5b0971d708ecac6e0eee167e1771e3139f4 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:08:56 +0800 Subject: [PATCH] [improvement](chat) Added query type parsing, supporting metrics, entities, and other query types. (#392) --- .../supersonic/chat/api/pojo/QueryType.java | 19 +++++ .../chat/api/pojo/SemanticParseInfo.java | 1 + .../chat/parser/QueryTypeParser.java | 85 +++++++++++++++++++ .../chat/query/rule/RuleSemanticQuery.java | 5 -- .../chat/service/impl/QueryServiceImpl.java | 3 +- .../jsqlparser/SqlParserEqualHelperTest.java | 28 ++++++ .../main/resources/META-INF/spring.factories | 4 +- .../main/resources/META-INF/spring.factories | 3 +- 8 files changed, 139 insertions(+), 9 deletions(-) create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/QueryType.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/QueryType.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/QueryType.java new file mode 100644 index 000000000..191961714 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/QueryType.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.chat.api.pojo; + +/*** + * Query Type + */ +public enum QueryType { + /** + * queries with metrics included in the select statement + */ + METRIC, + /** + * queries with entity unique key included in the select statement + */ + ENTITY, + /** + * the other queries + */ + OTHER +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java index 79e8df5cd..57aeb2e39 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java @@ -40,6 +40,7 @@ public class SemanticParseInfo { private Map properties = new HashMap<>(); private EntityInfo entityInfo; private SqlInfo sqlInfo = new SqlInfo(); + private QueryType queryType = QueryType.OTHER; public Long getModelId() { return model != null ? model.getId() : 0L; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java new file mode 100644 index 000000000..42b27309a --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java @@ -0,0 +1,85 @@ +package com.tencent.supersonic.chat.parser; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.component.SemanticParser; +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ChatContext; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.QueryType; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; +import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery; +import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; +import com.tencent.supersonic.chat.service.SemanticService; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; + +/** + * Query type parser, determine if the query is a metric query, a entity query, + * or another type of query. + */ +@Slf4j +public class QueryTypeParser implements SemanticParser { + + @Override + public void parse(QueryContext queryContext, ChatContext chatContext) { + + List candidateQueries = queryContext.getCandidateQueries(); + User user = queryContext.getRequest().getUser(); + + for (SemanticQuery semanticQuery : candidateQueries) { + // 1.init S2SQL + semanticQuery.initS2Sql(user); + // 2.set queryType + QueryType queryType = getQueryType(user, semanticQuery); + semanticQuery.getParseInfo().setQueryType(queryType); + } + } + + private QueryType getQueryType(User user, SemanticQuery semanticQuery) { + SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); + SqlInfo sqlInfo = parseInfo.getSqlInfo(); + if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) { + return QueryType.OTHER; + } + //1. entity queryType + if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) { + // get primaryEntityBizName + SemanticService semanticService = ContextUtils.getBean(SemanticService.class); + EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user); + String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo); + if (StringUtils.isNotEmpty(primaryEntityBizName)) { + //if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true + boolean existPrimaryEntityBizName = parseInfo.getDimensions().stream() + .anyMatch(schemaElement -> primaryEntityBizName.equalsIgnoreCase(schemaElement.getBizName())); + if (existPrimaryEntityBizName) { + return QueryType.ENTITY; + } + } + } + //2. metric queryType + List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL()); + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + List metrics = semanticSchema.getMetrics(parseInfo.getModelId()); + if (CollectionUtils.isNotEmpty(metrics)) { + Set metricNameSet = metrics.stream().map(metric -> metric.getName()).collect(Collectors.toSet()); + boolean containMetric = selectFields.stream().anyMatch(selectField -> metricNameSet.contains(selectField)); + if (containMetric) { + return QueryType.METRIC; + } + } + return QueryType.OTHER; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index 76cf78637..b24035d64 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -256,11 +256,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { return queryResult; } - @Override - public SemanticParseInfo getParseInfo() { - return parseInfo; - } - @Override public void setParseInfo(SemanticParseInfo parseInfo) { this.parseInfo = parseInfo; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index cf64afd05..134d5c548 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -148,8 +148,7 @@ public class QueryServiceImpl implements QueryService { List candidateQueries = queryCtx.getCandidateQueries(); if (CollectionUtils.isNotEmpty(candidateQueries)) { for (SemanticQuery semanticQuery : candidateQueries) { - semanticQuery.initS2Sql(queryReq.getUser()); - // rule + // the rules are not being corrected. if (semanticQuery instanceof RuleSemanticQuery) { continue; } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java index 3995a0f55..104c87e29 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java @@ -36,6 +36,34 @@ class SqlParserEqualHelperTest { sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false); + sql1 = "SELECT\n" + + "页面,\n" + + "SUM(访问次数)\n" + + "FROM\n" + + "超音数\n" + + "WHERE\n" + + "数据日期 >= '2023-10-26'\n" + + "AND 数据日期 <= '2023-11-09'\n" + + "AND department = \"HR\"\n" + + "GROUP BY\n" + + "页面\n" + + "LIMIT\n" + + "365"; + sql2 = "SELECT\n" + + "页面,\n" + + "SUM(访问次数)\n" + + "FROM\n" + + "超音数\n" + + "WHERE\n" + + "department = \"HR\"\n" + + "AND 数据日期 >= '2023-10-26'\n" + + "AND 数据日期 <= '2023-11-09'\n" + + "GROUP BY\n" + + "页面\n" + + "LIMIT\n" + + "365"; + + Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true); } } \ No newline at end of file diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 28f1a0e2f..01731b73f 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -8,7 +8,9 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\ com.tencent.supersonic.chat.api.component.SemanticParser=\ com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \ com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \ - com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser + com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \ + com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \ + com.tencent.supersonic.chat.parser.QueryTypeParser com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.corrector.SchemaCorrector, \ diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 531fb6bb7..66e211d18 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -9,7 +9,8 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\ com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \ com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \ com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \ - com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser + com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \ + com.tencent.supersonic.chat.parser.QueryTypeParser com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.corrector.SchemaCorrector, \