mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-18 00:07:21 +00:00
(improvement)(headless)Introduce DetailDimensionQuery as a type of rule-based parsing query.
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.DIMENSION;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Component
|
||||
public class DetailDimensionQuery extends DetailSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DETAIL_DIMENSION";
|
||||
|
||||
public DetailDimensionQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(DIMENSION, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0);
|
||||
queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
@@ -16,6 +17,7 @@ public class DetailFilterQuery extends DetailListQuery {
|
||||
public DetailFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -5,10 +5,8 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -24,8 +22,6 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
|
||||
|
||||
public DetailSemanticQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(SchemaElementType.ENTITY, QueryMatchOption.OptionType.REQUIRED,
|
||||
QueryMatchOption.RequireNumberType.AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -122,7 +122,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
DimensionType.categorical.name(), 1, 1));
|
||||
dimensions.add(new Dim("代表作", "song_name",
|
||||
DimensionType.categorical.name(), 1));
|
||||
dimensions.add(new Dim("风格", "genre",
|
||||
dimensions.add(new Dim("流派", "genre",
|
||||
DimensionType.categorical.name(), 1, 1));
|
||||
modelDetail.setDimensions(dimensions);
|
||||
|
||||
@@ -191,7 +191,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人"));
|
||||
agent.setExamples(Lists.newArrayList("国风流派艺人", "港台地区的艺人", "流派为流行的艺人"));
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setId("0");
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
@@ -7,7 +8,9 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailFilterQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailIdQuery;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -18,36 +21,78 @@ import org.springframework.boot.test.context.SpringBootTest;
|
||||
public class TagTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_tag_list_filter() throws Exception {
|
||||
public void test_detail_dimension() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.tagAgentId);
|
||||
|
||||
log.info("queryTest_tag_list_filter start");
|
||||
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId);
|
||||
log.info("actualResult:{}", actualResult);
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(DetailDimensionQuery.QUERY_MODE);
|
||||
expectedParseInfo.setQueryType(QueryType.DETAIL);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
|
||||
"周杰伦", "歌手名", 8L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
|
||||
SchemaElement.builder().name("流派").build(),
|
||||
SchemaElement.builder().name("代表作").build()));
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_detail_id() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("周杰伦", DataUtils.tagAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(DetailIdQuery.QUERY_MODE);
|
||||
expectedParseInfo.setQueryType(QueryType.DETAIL);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
|
||||
"周杰伦", "歌手名", 8L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build());
|
||||
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
|
||||
SchemaElement.builder().name("歌手名").build(),
|
||||
SchemaElement.builder().name("活跃区域").build(),
|
||||
SchemaElement.builder().name("流派").build(),
|
||||
SchemaElement.builder().name("代表作").build()
|
||||
));
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_detail_list_filter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("国风艺人", DataUtils.tagAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(DetailFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setQueryType(QueryType.DETAIL);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
|
||||
"流行", "风格", 7L);
|
||||
"国风", "流派", 7L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
SchemaElement metric = SchemaElement.builder().name("播放量").build();
|
||||
expectedParseInfo.getMetrics().add(metric);
|
||||
|
||||
SchemaElement dim1 = SchemaElement.builder().name("歌手名").build();
|
||||
SchemaElement dim2 = SchemaElement.builder().name("活跃区域").build();
|
||||
SchemaElement dim3 = SchemaElement.builder().name("风格").build();
|
||||
SchemaElement dim4 = SchemaElement.builder().name("代表作").build();
|
||||
expectedParseInfo.getDimensions().add(dim1);
|
||||
expectedParseInfo.getDimensions().add(dim2);
|
||||
expectedParseInfo.getDimensions().add(dim3);
|
||||
expectedParseInfo.getDimensions().add(dim4);
|
||||
|
||||
expectedParseInfo.setQueryType(QueryType.DETAIL);
|
||||
expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build());
|
||||
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
|
||||
SchemaElement.builder().name("歌手名").build(),
|
||||
SchemaElement.builder().name("活跃区域").build(),
|
||||
SchemaElement.builder().name("流派").build(),
|
||||
SchemaElement.builder().name("代表作").build()
|
||||
));
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user