(improvement)(headless)Introduce DetailDimensionQuery as a type of rule-based parsing query.

This commit is contained in:
jerryjzhang
2024-08-10 18:27:50 +08:00
parent ecc651e12d
commit ba9e6afa51
5 changed files with 97 additions and 24 deletions

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Component
public class DetailDimensionQuery extends DetailSemanticQuery {
public static final String QUERY_MODE = "DETAIL_DIMENSION";
public DetailDimensionQuery() {
super();
queryMatcher.addOption(DIMENSION, REQUIRED, AT_LEAST, 1);
queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0);
queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.rule.detail;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE; import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED; import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST; import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@@ -16,6 +17,7 @@ public class DetailFilterQuery extends DetailListQuery {
public DetailFilterQuery() { public DetailFilterQuery() {
super(); super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1); queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
} }
@Override @Override

View File

@@ -5,10 +5,8 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -24,8 +22,6 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
public DetailSemanticQuery() { public DetailSemanticQuery() {
super(); super();
queryMatcher.addOption(SchemaElementType.ENTITY, QueryMatchOption.OptionType.REQUIRED,
QueryMatchOption.RequireNumberType.AT_LEAST, 1);
} }
@Override @Override

View File

@@ -122,7 +122,7 @@ public class S2ArtistDemo extends S2BaseDemo {
DimensionType.categorical.name(), 1, 1)); DimensionType.categorical.name(), 1, 1));
dimensions.add(new Dim("代表作", "song_name", dimensions.add(new Dim("代表作", "song_name",
DimensionType.categorical.name(), 1)); DimensionType.categorical.name(), 1));
dimensions.add(new Dim("风格", "genre", dimensions.add(new Dim("流派", "genre",
DimensionType.categorical.name(), 1, 1)); DimensionType.categorical.name(), 1, 1));
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
@@ -191,7 +191,7 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选"); agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
agent.setStatus(1); agent.setStatus(1);
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人")); agent.setExamples(Lists.newArrayList("国风流派艺人", "港台地区的艺人", "流派为流行的艺人"));
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat; package com.tencent.supersonic.chat;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
@@ -7,7 +8,9 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery;
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailFilterQuery; import com.tencent.supersonic.headless.chat.query.rule.detail.DetailFilterQuery;
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailIdQuery;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -18,36 +21,78 @@ import org.springframework.boot.test.context.SpringBootTest;
public class TagTest extends BaseTest { public class TagTest extends BaseTest {
@Test @Test
public void queryTest_tag_list_filter() throws Exception { public void test_detail_dimension() throws Exception {
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.tagAgentId);
log.info("queryTest_tag_list_filter start"); QueryResult expectedResult = new QueryResult();
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
log.info("actualResult:{}", actualResult); expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(DetailDimensionQuery.QUERY_MODE);
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", 8L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()));
assertQueryResult(expectedResult, actualResult);
}
@Test
public void test_detail_id() throws Exception {
QueryResult actualResult = submitNewChat("周杰伦", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(DetailIdQuery.QUERY_MODE);
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", 8L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build());
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
SchemaElement.builder().name("歌手名").build(),
SchemaElement.builder().name("活跃区域").build(),
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()
));
assertQueryResult(expectedResult, actualResult);
}
@Test
public void test_detail_list_filter() throws Exception {
QueryResult actualResult = submitNewChat("国风艺人", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult(); QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo); expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(DetailFilterQuery.QUERY_MODE); expectedResult.setQueryMode(DetailFilterQuery.QUERY_MODE);
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE); expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"流行", "风格", 7L); "国风", "流派", 7L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build(); expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build());
expectedParseInfo.getMetrics().add(metric); expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
SchemaElement.builder().name("歌手名").build(),
SchemaElement dim1 = SchemaElement.builder().name("歌手名").build(); SchemaElement.builder().name("活跃区域").build(),
SchemaElement dim2 = SchemaElement.builder().name("活跃区域").build(); SchemaElement.builder().name("流派").build(),
SchemaElement dim3 = SchemaElement.builder().name("风格").build(); SchemaElement.builder().name("代表作").build()
SchemaElement dim4 = SchemaElement.builder().name("代表作").build(); ));
expectedParseInfo.getDimensions().add(dim1);
expectedParseInfo.getDimensions().add(dim2);
expectedParseInfo.getDimensions().add(dim3);
expectedParseInfo.getDimensions().add(dim4);
expectedParseInfo.setQueryType(QueryType.DETAIL);
assertQueryResult(expectedResult, actualResult); assertQueryResult(expectedResult, actualResult);
} }