mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement](chat) Added query type parsing, supporting metrics, entities, and other query types. (#392)
This commit is contained in:
@@ -0,0 +1,19 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
/**
 * Classifies a parsed chat query by what its select statement contains.
 *
 * <p>The constants are declared in the order METRIC, ENTITY, OTHER.
 */
public enum QueryType {

    /**
     * A query whose select statement includes at least one metric.
     */
    METRIC,

    /**
     * A query whose select statement includes the unique key of an entity.
     */
    ENTITY,

    /**
     * Any query that is neither a metric query nor an entity query.
     */
    OTHER
}
|
||||||
@@ -40,6 +40,7 @@ public class SemanticParseInfo {
|
|||||||
private Map<String, Object> properties = new HashMap<>();
|
private Map<String, Object> properties = new HashMap<>();
|
||||||
private EntityInfo entityInfo;
|
private EntityInfo entityInfo;
|
||||||
private SqlInfo sqlInfo = new SqlInfo();
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
|
private QueryType queryType = QueryType.OTHER;
|
||||||
|
|
||||||
public Long getModelId() {
|
public Long getModelId() {
|
||||||
return model != null ? model.getId() : 0L;
|
return model != null ? model.getId() : 0L;
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query type parser, determine if the query is a metric query, a entity query,
|
||||||
|
* or another type of query.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class QueryTypeParser implements SemanticParser {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
|
||||||
|
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||||
|
User user = queryContext.getRequest().getUser();
|
||||||
|
|
||||||
|
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||||
|
// 1.init S2SQL
|
||||||
|
semanticQuery.initS2Sql(user);
|
||||||
|
// 2.set queryType
|
||||||
|
QueryType queryType = getQueryType(user, semanticQuery);
|
||||||
|
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private QueryType getQueryType(User user, SemanticQuery semanticQuery) {
|
||||||
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
|
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||||
|
return QueryType.OTHER;
|
||||||
|
}
|
||||||
|
//1. entity queryType
|
||||||
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||||
|
// get primaryEntityBizName
|
||||||
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
|
||||||
|
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
|
||||||
|
if (StringUtils.isNotEmpty(primaryEntityBizName)) {
|
||||||
|
//if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true
|
||||||
|
boolean existPrimaryEntityBizName = parseInfo.getDimensions().stream()
|
||||||
|
.anyMatch(schemaElement -> primaryEntityBizName.equalsIgnoreCase(schemaElement.getBizName()));
|
||||||
|
if (existPrimaryEntityBizName) {
|
||||||
|
return QueryType.ENTITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//2. metric queryType
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModelId());
|
||||||
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
|
Set<String> metricNameSet = metrics.stream().map(metric -> metric.getName()).collect(Collectors.toSet());
|
||||||
|
boolean containMetric = selectFields.stream().anyMatch(selectField -> metricNameSet.contains(selectField));
|
||||||
|
if (containMetric) {
|
||||||
|
return QueryType.METRIC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return QueryType.OTHER;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -256,11 +256,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
return queryResult;
|
return queryResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public SemanticParseInfo getParseInfo() {
|
|
||||||
return parseInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||||
this.parseInfo = parseInfo;
|
this.parseInfo = parseInfo;
|
||||||
|
|||||||
@@ -148,8 +148,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
||||||
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
||||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||||
semanticQuery.initS2Sql(queryReq.getUser());
|
// the rules are not being corrected.
|
||||||
// rule
|
|
||||||
if (semanticQuery instanceof RuleSemanticQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,34 @@ class SqlParserEqualHelperTest {
|
|||||||
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false);
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false);
|
||||||
|
|
||||||
|
sql1 = "SELECT\n"
|
||||||
|
+ "页面,\n"
|
||||||
|
+ "SUM(访问次数)\n"
|
||||||
|
+ "FROM\n"
|
||||||
|
+ "超音数\n"
|
||||||
|
+ "WHERE\n"
|
||||||
|
+ "数据日期 >= '2023-10-26'\n"
|
||||||
|
+ "AND 数据日期 <= '2023-11-09'\n"
|
||||||
|
+ "AND department = \"HR\"\n"
|
||||||
|
+ "GROUP BY\n"
|
||||||
|
+ "页面\n"
|
||||||
|
+ "LIMIT\n"
|
||||||
|
+ "365";
|
||||||
|
|
||||||
|
sql2 = "SELECT\n"
|
||||||
|
+ "页面,\n"
|
||||||
|
+ "SUM(访问次数)\n"
|
||||||
|
+ "FROM\n"
|
||||||
|
+ "超音数\n"
|
||||||
|
+ "WHERE\n"
|
||||||
|
+ "department = \"HR\"\n"
|
||||||
|
+ "AND 数据日期 >= '2023-10-26'\n"
|
||||||
|
+ "AND 数据日期 <= '2023-11-09'\n"
|
||||||
|
+ "GROUP BY\n"
|
||||||
|
+ "页面\n"
|
||||||
|
+ "LIMIT\n"
|
||||||
|
+ "365";
|
||||||
|
|
||||||
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -8,7 +8,9 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
|||||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||||
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
|
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
|
||||||
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
|
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
|
||||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
|
||||||
|
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
|
||||||
|
com.tencent.supersonic.chat.parser.QueryTypeParser
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||||
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
|
|||||||
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
|
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
|
||||||
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
|
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
|
||||||
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
|
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
|
||||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
|
||||||
|
com.tencent.supersonic.chat.parser.QueryTypeParser
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||||
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||||
|
|||||||
Reference in New Issue
Block a user