mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[improvement](chat) remove nativeQuery config in chat (#394)
This commit is contained in:
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
@@ -26,7 +26,7 @@ import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* Query type parser, determine if the query is a metric query, a entity query,
|
||||
* Query type parser, determine if the query is a metric query, an entity query,
|
||||
* or another type of query.
|
||||
*/
|
||||
@Slf4j
|
||||
|
||||
@@ -17,6 +17,7 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
@@ -89,7 +90,7 @@ public class MetricInterpretQuery extends LLMSemanticQuery {
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
|
||||
fillAggregator(queryStructReq, parseInfo.getMetrics());
|
||||
queryStructReq.setNativeQuery(true);
|
||||
queryStructReq.setQueryType(QueryType.ENTITY);
|
||||
return queryStructReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
@@ -11,17 +16,12 @@ import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class EntitySemanticQuery extends RuleSemanticQuery {
|
||||
@@ -81,7 +81,7 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
|
||||
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(modelId, queryContext, chatContext);
|
||||
|
||||
parseInfo.setNativeQuery(true);
|
||||
parseInfo.setQueryType(QueryType.ENTITY);
|
||||
parseInfo.setLimit(ENTITY_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
ConfigService configService = ContextUtils.getBean(ConfigService.class);
|
||||
|
||||
@@ -9,12 +9,6 @@ import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class EntityInfoExecuteResponder implements ExecuteResponder {
|
||||
|
||||
@@ -33,29 +27,6 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, user);
|
||||
queryResult.setEntityInfo(entityInfo);
|
||||
|
||||
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
|
||||
if (StringUtils.isEmpty(primaryEntityBizName)
|
||||
|| CollectionUtils.isEmpty(queryResult.getQueryColumns())) {
|
||||
return;
|
||||
}
|
||||
boolean existPrimaryEntityName = queryResult.getQueryColumns().stream()
|
||||
.anyMatch(queryColumn -> primaryEntityBizName.equals(queryColumn.getNameEn()));
|
||||
|
||||
semanticParseInfo.setNativeQuery(existPrimaryEntityName);
|
||||
|
||||
if (!existPrimaryEntityName) {
|
||||
return;
|
||||
}
|
||||
List<Map<String, Object>> queryResults = queryResult.getQueryResults();
|
||||
List<String> entities = queryResults.stream()
|
||||
.map(entry -> entry.get(primaryEntityBizName))
|
||||
.filter(Objects::nonNull)
|
||||
.map(String::valueOf)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isEmpty(entities)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public class EntityInfoParseResponder implements ParseResponder {
|
||||
@@ -37,14 +36,6 @@ public class EntityInfoParseResponder implements ParseResponder {
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
}
|
||||
//2. set native value
|
||||
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
|
||||
if (StringUtils.isNotEmpty(primaryEntityBizName)) {
|
||||
//if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true
|
||||
boolean existPrimaryEntityBizName = parseInfo.getDimensions().stream()
|
||||
.anyMatch(schemaElement -> primaryEntityBizName.equalsIgnoreCase(schemaElement.getBizName()));
|
||||
parseInfo.setNativeQuery(existPrimaryEntityBizName);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
@@ -233,7 +234,7 @@ public class SemanticService {
|
||||
modelInfo.setEntityId(entities.get(0));
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setModel(modelSchema.getModel());
|
||||
semanticParseInfo.setNativeQuery(true);
|
||||
semanticParseInfo.setQueryType(QueryType.ENTITY);
|
||||
semanticParseInfo.setMetrics(getMetrics(modelInfo));
|
||||
semanticParseInfo.setDimensions(getDimensions(modelInfo));
|
||||
DateConf dateInfo = new DateConf();
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
@@ -79,13 +80,13 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
parseInfo.setQueryType(QueryType.METRIC);
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setNativeQuery(true);
|
||||
parseInfo.setQueryType(QueryType.ENTITY);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(
|
||||
|
||||
@@ -42,6 +42,7 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
@@ -609,7 +610,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
queryStructReq.setDateInfo(dateConf);
|
||||
queryStructReq.setLimit(20L);
|
||||
queryStructReq.setModelId(dimensionValueReq.getModelId());
|
||||
queryStructReq.setNativeQuery(false);
|
||||
queryStructReq.setQueryType(QueryType.OTHER);
|
||||
List<String> groups = new ArrayList<>();
|
||||
groups.add(dimensionValueReq.getBizName());
|
||||
queryStructReq.setGroups(groups);
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
@@ -36,7 +37,7 @@ public class QueryReqBuilder {
|
||||
public static QueryStructReq buildStructReq(SemanticParseInfo parseInfo) {
|
||||
QueryStructReq queryStructCmd = new QueryStructReq();
|
||||
queryStructCmd.setModelId(parseInfo.getModelId());
|
||||
queryStructCmd.setNativeQuery(parseInfo.getNativeQuery());
|
||||
queryStructCmd.setQueryType(parseInfo.getQueryType());
|
||||
queryStructCmd.setDateInfo(rewrite2Between(parseInfo.getDateInfo()));
|
||||
|
||||
List<Filter> dimensionFilters = parseInfo.getDimensionFilters().stream()
|
||||
@@ -231,7 +232,7 @@ public class QueryReqBuilder {
|
||||
public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo, SchemaElement metric,
|
||||
AggOperatorEnum aggOperatorEnum) {
|
||||
QueryStructReq queryStructCmd = buildStructReq(parseInfo);
|
||||
queryStructCmd.setNativeQuery(false);
|
||||
queryStructCmd.setQueryType(QueryType.METRIC);
|
||||
queryStructCmd.setOrders(new ArrayList<>());
|
||||
List<Aggregator> aggregators = new ArrayList<>();
|
||||
Aggregator ratioRoll = new Aggregator(metric.getBizName(), aggOperatorEnum);
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateModeUtils;
|
||||
@@ -29,7 +30,7 @@ class QueryReqBuilderTest {
|
||||
init();
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setModelId(1L);
|
||||
queryStructReq.setNativeQuery(false);
|
||||
queryStructReq.setQueryType(QueryType.METRIC);
|
||||
queryStructReq.setModelName("内容库");
|
||||
|
||||
Aggregator aggregator = new Aggregator();
|
||||
@@ -55,7 +56,7 @@ class QueryReqBuilderTest {
|
||||
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
|
||||
|
||||
queryStructReq.setNativeQuery(true);
|
||||
queryStructReq.setQueryType(QueryType.ENTITY);
|
||||
queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
Assert.assertEquals(
|
||||
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
|
||||
Reference in New Issue
Block a user