(improvement)(Headless) Change queryMode TAG_MODE to DETAIL_MODE (#1050)

This commit is contained in:
LXW
2024-05-30 10:49:09 +08:00
committed by GitHub
parent 6aaf471582
commit 40b3142730
18 changed files with 52 additions and 75 deletions

View File

@@ -11,13 +11,13 @@ public enum QueryType {
/**
* queries with tag-based entity targeting
*/
TAG,
DETAIL,
/**
* queries with ID-based entity selection
*/
ID;
public boolean isNativeAggQuery() {
return TAG.equals(this);
return DETAIL.equals(this);
}
}

View File

@@ -13,16 +13,15 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
/**
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
*/
@@ -73,7 +72,7 @@ public class QueryTypeParser implements SemanticParser {
//If all the fields in the SELECT/WHERE statement are of tag type.
if (CollectionUtils.isNotEmpty(tags)
&& tags.containsAll(selectWhereFilterByTimeFields)) {
return QueryType.TAG;
return QueryType.DETAIL;
}
}
}

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.query;
import com.tencent.supersonic.headless.core.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.tag.TagSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.detail.DetailSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -70,7 +70,7 @@ public class QueryManager {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof TagSemanticQuery;
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
}
public static RuleSemanticQuery getRuleQuery(String queryMode) {

View File

@@ -1,19 +1,19 @@
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
package com.tencent.supersonic.headless.core.chat.query.rule.detail;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class TagFilterQuery extends TagListQuery {
public class DetailFilterQuery extends DetailListQuery {
public static final String QUERY_MODE = "TAG_LIST_FILTER";
public static final String QUERY_MODE = "DETAIL_LIST_FILTER";
public TagFilterQuery() {
public DetailFilterQuery() {
super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
package com.tencent.supersonic.headless.core.chat.query.rule.detail;
import org.springframework.stereotype.Component;
@@ -7,11 +7,11 @@ import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOpt
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
@Component
public class TagIdQuery extends TagListQuery {
public class DetailIdQuery extends DetailListQuery {
public static final String QUERY_MODE = "TAG_ID";
public static final String QUERY_MODE = "DETAIL_ID";
public TagIdQuery() {
public DetailIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
package com.tencent.supersonic.headless.core.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order;
@@ -16,7 +16,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
public abstract class TagListQuery extends TagSemanticQuery {
public abstract class DetailListQuery extends DetailSemanticQuery {
@Override
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
package com.tencent.supersonic.headless.core.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.QueryType;
@@ -20,11 +20,11 @@ import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOpt
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
public abstract class TagSemanticQuery extends RuleSemanticQuery {
public abstract class DetailSemanticQuery extends RuleSemanticQuery {
private static final Long TAG_MAX_RESULTS = 500L;
private static final Long DETAIL_MAX_RESULTS = 500L;
public TagSemanticQuery() {
public DetailSemanticQuery() {
super();
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
}
@@ -39,8 +39,8 @@ public abstract class TagSemanticQuery extends RuleSemanticQuery {
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(queryContext, chatContext);
parseInfo.setQueryType(QueryType.TAG);
parseInfo.setLimit(TAG_MAX_RESULTS);
parseInfo.setQueryType(QueryType.DETAIL);
parseInfo.setLimit(DETAIL_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
DataSetSchema dataSetSchema =
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());

View File

@@ -1,24 +0,0 @@
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import org.springframework.stereotype.Component;
@Component
public class TagDetailQuery extends TagSemanticQuery {
public static final String QUERY_MODE = "TAG_DETAIL";
public TagDetailQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
}

View File

@@ -37,7 +37,7 @@ public class S2SqlDateHelper {
return Pair.of(defaultDate, defaultDate);
}
TimeDefaultConfig defaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
if (QueryType.TAG.equals(queryType)) {
if (QueryType.DETAIL.equals(queryType)) {
defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
}
return getDefaultDate(defaultDate, defaultConfig);

View File

@@ -57,11 +57,11 @@ class S2SqlDateHelperTest {
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.TAG);
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.DETAIL);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
@@ -74,7 +74,7 @@ class S2SqlDateHelperTest {
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
@@ -88,7 +88,7 @@ class S2SqlDateHelperTest {
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));

View File

@@ -94,7 +94,7 @@ public class ParseInfoProcessor implements ResultProcessor {
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));

View File

@@ -600,7 +600,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
sqlInfo.setS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setQueryType(QueryType.TAG);
semanticParseInfo.setQueryType(QueryType.DETAIL);
Long dataSetId = querySqlReq.getDataSetId();
if (Objects.isNull(dataSetId)) {

View File

@@ -223,7 +223,8 @@ public class DataSetServiceImpl
queryReq = new QuerySqlReq();
}
BeanUtils.copyProperties(queryDataSetReq, queryReq);
if (Objects.nonNull(queryDataSetReq.getQueryType()) && QueryType.TAG.equals(queryDataSetReq.getQueryType())) {
if (Objects.nonNull(queryDataSetReq.getQueryType())
&& QueryType.DETAIL.equals(queryDataSetReq.getQueryType())) {
queryReq.setInnerLayerNative(true);
}
return queryReq;

View File

@@ -151,7 +151,7 @@ public class SemanticService {
DataSetSchema dataSetSchema, User user) {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
semanticParseInfo.setQueryType(QueryType.TAG);
semanticParseInfo.setQueryType(QueryType.DETAIL);
semanticParseInfo.setMetrics(getMetrics(entityInfo));
semanticParseInfo.setDimensions(getDimensions(entityInfo));
DateConf dateInfo = new DateConf();

View File

@@ -28,6 +28,15 @@ import com.tencent.supersonic.headless.core.adaptor.db.DbAdaptorFactory;
import com.tencent.supersonic.headless.core.pojo.DataSetQueryParam;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.SqlGenerateUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
@@ -37,14 +46,6 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
@@ -242,7 +243,7 @@ public class QueryReqConverter {
private QueryType getQueryType(AggOption aggOption) {
boolean isAgg = AggOption.isAgg(aggOption);
QueryType queryType = QueryType.TAG;
QueryType queryType = QueryType.DETAIL;
if (isAgg) {
queryType = QueryType.METRIC;
}

View File

@@ -58,7 +58,7 @@ class QueryReqBuilderTest {
+ "WHERE (sys_imp_date IN ('2023-08-01')) GROUP "
+ "BY department ORDER BY uv LIMIT 2000", querySQLReq.getSql());
queryStructReq.setQueryType(QueryType.TAG);
queryStructReq.setQueryType(QueryType.DETAIL);
querySQLReq = queryStructReq.convert();
Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.core.chat.query.rule.tag.TagFilterQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.detail.DetailFilterQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test;
@@ -22,7 +22,7 @@ public class TagTest extends BaseTest {
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(TagFilterQuery.QUERY_MODE);
expectedResult.setQueryMode(DetailFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
@@ -42,7 +42,7 @@ public class TagTest extends BaseTest {
expectedParseInfo.getDimensions().add(dim4);
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, startDay, startDay, 7));
expectedParseInfo.setQueryType(QueryType.TAG);
expectedParseInfo.setQueryType(QueryType.DETAIL);
assertQueryResult(expectedResult, actualResult);
}

View File

@@ -27,7 +27,7 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testDetailQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("user_name", "department"),
QueryType.TAG);
QueryType.DETAIL);
SemanticQueryResp semanticQueryResp = queryService.queryByReq(queryStructReq, User.getFakeUser());
assertEquals(3, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);