mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement](chat) Added query type parsing, supporting metrics, entities, and other query types. (#392)
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
/**
 * Classifies a parsed chat query by what its select statement contains.
 */
public enum QueryType {

    /**
     * Queries with metrics included in the select statement.
     */
    METRIC,

    /**
     * Queries with the entity unique key included in the select statement.
     */
    ENTITY,

    /**
     * All other queries, i.e. those that are neither {@link #METRIC} nor {@link #ENTITY}.
     */
    OTHER
}
|
||||
@@ -40,6 +40,7 @@ public class SemanticParseInfo {
|
||||
private Map<String, Object> properties = new HashMap<>();
|
||||
private EntityInfo entityInfo;
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private QueryType queryType = QueryType.OTHER;
|
||||
|
||||
public Long getModelId() {
|
||||
return model != null ? model.getId() : 0L;
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* Query type parser, determine if the query is a metric query, a entity query,
|
||||
* or another type of query.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||
User user = queryContext.getRequest().getUser();
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(user, semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(User user, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
return QueryType.OTHER;
|
||||
}
|
||||
//1. entity queryType
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||
// get primaryEntityBizName
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
|
||||
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
|
||||
if (StringUtils.isNotEmpty(primaryEntityBizName)) {
|
||||
//if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true
|
||||
boolean existPrimaryEntityBizName = parseInfo.getDimensions().stream()
|
||||
.anyMatch(schemaElement -> primaryEntityBizName.equalsIgnoreCase(schemaElement.getBizName()));
|
||||
if (existPrimaryEntityBizName) {
|
||||
return QueryType.ENTITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModelId());
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(metric -> metric.getName()).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(selectField -> metricNameSet.contains(selectField));
|
||||
if (containMetric) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
}
|
||||
return QueryType.OTHER;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -256,11 +256,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo getParseInfo() {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
|
||||
@@ -148,8 +148,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
semanticQuery.initS2Sql(queryReq.getUser());
|
||||
// rule
|
||||
// the rules are not being corrected.
|
||||
if (semanticQuery instanceof RuleSemanticQuery) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -36,6 +36,34 @@ class SqlParserEqualHelperTest {
|
||||
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false);
|
||||
|
||||
sql1 = "SELECT\n"
|
||||
+ "页面,\n"
|
||||
+ "SUM(访问次数)\n"
|
||||
+ "FROM\n"
|
||||
+ "超音数\n"
|
||||
+ "WHERE\n"
|
||||
+ "数据日期 >= '2023-10-26'\n"
|
||||
+ "AND 数据日期 <= '2023-11-09'\n"
|
||||
+ "AND department = \"HR\"\n"
|
||||
+ "GROUP BY\n"
|
||||
+ "页面\n"
|
||||
+ "LIMIT\n"
|
||||
+ "365";
|
||||
|
||||
sql2 = "SELECT\n"
|
||||
+ "页面,\n"
|
||||
+ "SUM(访问次数)\n"
|
||||
+ "FROM\n"
|
||||
+ "超音数\n"
|
||||
+ "WHERE\n"
|
||||
+ "department = \"HR\"\n"
|
||||
+ "AND 数据日期 >= '2023-10-26'\n"
|
||||
+ "AND 数据日期 <= '2023-11-09'\n"
|
||||
+ "GROUP BY\n"
|
||||
+ "页面\n"
|
||||
+ "LIMIT\n"
|
||||
+ "365";
|
||||
|
||||
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,9 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
|
||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
||||
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.QueryTypeParser
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||
|
||||
@@ -9,7 +9,8 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
|
||||
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.QueryTypeParser
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||
|
||||
Reference in New Issue
Block a user