[improvement](chat) Added query type parsing, supporting metrics, entities, and other query types. (#392)

This commit is contained in:
lexluo09
2023-11-16 17:08:56 +08:00
committed by GitHub
parent 8c65ac80b5
commit 46a9e5b097
8 changed files with 139 additions and 9 deletions

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.api.pojo;
/***
* Query Type
*/
public enum QueryType {
/**
* queries with metrics included in the select statement
*/
METRIC,
/**
* queries with entity unique key included in the select statement
*/
ENTITY,
/**
* the other queries
*/
OTHER
}

View File

@@ -40,6 +40,7 @@ public class SemanticParseInfo {
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
private SqlInfo sqlInfo = new SqlInfo();
private QueryType queryType = QueryType.OTHER;
public Long getModelId() {
return model != null ? model.getId() : 0L;

View File

@@ -0,0 +1,85 @@
package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.QueryType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
/**
* Query type parser, determine if the query is a metric query, a entity query,
* or another type of query.
*/
@Slf4j
public class QueryTypeParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
User user = queryContext.getRequest().getUser();
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL
semanticQuery.initS2Sql(user);
// 2.set queryType
QueryType queryType = getQueryType(user, semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);
}
}
private QueryType getQueryType(User user, SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
return QueryType.OTHER;
}
//1. entity queryType
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
// get primaryEntityBizName
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
if (StringUtils.isNotEmpty(primaryEntityBizName)) {
//if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true
boolean existPrimaryEntityBizName = parseInfo.getDimensions().stream()
.anyMatch(schemaElement -> primaryEntityBizName.equalsIgnoreCase(schemaElement.getBizName()));
if (existPrimaryEntityBizName) {
return QueryType.ENTITY;
}
}
}
//2. metric queryType
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModelId());
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(metric -> metric.getName()).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(selectField -> metricNameSet.contains(selectField));
if (containMetric) {
return QueryType.METRIC;
}
}
return QueryType.OTHER;
}
}

View File

@@ -256,11 +256,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return queryResult;
}
@Override
public SemanticParseInfo getParseInfo() {
return parseInfo;
}
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;

View File

@@ -148,8 +148,7 @@ public class QueryServiceImpl implements QueryService {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
semanticQuery.initS2Sql(queryReq.getUser());
// rule
// the rules are not being corrected.
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}

View File

@@ -36,6 +36,34 @@ class SqlParserEqualHelperTest {
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false);
sql1 = "SELECT\n"
+ "页面,\n"
+ "SUM(访问次数)\n"
+ "FROM\n"
+ "超音数\n"
+ "WHERE\n"
+ "数据日期 >= '2023-10-26'\n"
+ "AND 数据日期 <= '2023-11-09'\n"
+ "AND department = \"HR\"\n"
+ "GROUP BY\n"
+ "页面\n"
+ "LIMIT\n"
+ "365";
sql2 = "SELECT\n"
+ "页面,\n"
+ "SUM(访问次数)\n"
+ "FROM\n"
+ "超音数\n"
+ "WHERE\n"
+ "department = \"HR\"\n"
+ "AND 数据日期 >= '2023-10-26'\n"
+ "AND 数据日期 <= '2023-11-09'\n"
+ "GROUP BY\n"
+ "页面\n"
+ "LIMIT\n"
+ "365";
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
}
}

View File

@@ -8,7 +8,9 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
com.tencent.supersonic.chat.parser.QueryTypeParser
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.SchemaCorrector, \

View File

@@ -9,7 +9,8 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
com.tencent.supersonic.chat.parser.QueryTypeParser
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.SchemaCorrector, \