mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement](project) Switch ENTITY to TAG uniformly in the queryType and semanticQuery (#420)
This commit is contained in:
@@ -4,16 +4,16 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
@@ -42,12 +42,12 @@ public class QueryTypeParser implements SemanticParser {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(user, semanticQuery);
|
||||
QueryType queryType = getQueryType(semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(User user, SemanticQuery semanticQuery) {
|
||||
private QueryType getQueryType(SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
@@ -55,18 +55,17 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
//1. tag queryType
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||
// get primaryEntityBizName
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
|
||||
if (Objects.nonNull(entityInfo) && Objects.nonNull(entityInfo.getModelInfo()) && StringUtils.isNotEmpty(
|
||||
entityInfo.getModelInfo().getPrimaryEntityName())) {
|
||||
String primaryEntityName = entityInfo.getModelInfo().getPrimaryEntityName();
|
||||
//if exist primaryEntityName in S2SQL select.
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
boolean existPrimaryEntityName = selectFields.stream()
|
||||
.anyMatch(fieldName -> primaryEntityName.equalsIgnoreCase(fieldName));
|
||||
if (existPrimaryEntityName) {
|
||||
return QueryType.ENTITY;
|
||||
ModelSchema modelSchema = semanticService.getModelSchema(parseInfo.getModelId());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(selectFields) && Objects.nonNull(modelSchema) && CollectionUtils.isNotEmpty(
|
||||
modelSchema.getTags())) {
|
||||
Set<String> tags = modelSchema.getTags().stream().map(schemaElement -> schemaElement.getName())
|
||||
.collect(Collectors.toSet());
|
||||
if (tags.containsAll(selectFields)) {
|
||||
return QueryType.TAG;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import java.util.AbstractMap;
|
||||
@@ -94,7 +94,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return matches.stream().anyMatch(m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricEntityQuery)) {
|
||||
&& !(ruleQuery instanceof MetricTagQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.tag.TagSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -83,7 +83,7 @@ public class QueryManager {
|
||||
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
|
||||
return false;
|
||||
}
|
||||
return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery;
|
||||
return ruleQueryMap.get(queryMode) instanceof TagSemanticQuery;
|
||||
}
|
||||
|
||||
public static boolean containsPluginQuery(String queryMode) {
|
||||
|
||||
@@ -90,7 +90,7 @@ public class MetricInterpretQuery extends LLMSemanticQuery {
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
|
||||
fillAggregator(queryStructReq, parseInfo.getMetrics());
|
||||
queryStructReq.setQueryType(QueryType.ENTITY);
|
||||
queryStructReq.setQueryType(QueryType.TAG);
|
||||
return queryStructReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,11 +23,11 @@ import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNum
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricEntityQuery extends MetricSemanticQuery {
|
||||
public class MetricTagQuery extends MetricSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_ENTITY";
|
||||
public static final String QUERY_MODE = "TAG_ENTITY";
|
||||
|
||||
public MetricEntityQuery() {
|
||||
public MetricTagQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
|
||||
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
package com.tencent.supersonic.chat.query.rule.tag;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
@@ -8,11 +8,11 @@ import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class EntityDetailQuery extends EntitySemanticQuery {
|
||||
public class TagDetailQuery extends TagSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "ENTITY_DETAIL";
|
||||
public static final String QUERY_MODE = "TAG_DETAIL";
|
||||
|
||||
public EntityDetailQuery() {
|
||||
public TagDetailQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(DIMENSION, REQUIRED, AT_LEAST, 1)
|
||||
.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
package com.tencent.supersonic.chat.query.rule.tag;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
@@ -10,11 +10,11 @@ import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class EntityFilterQuery extends EntityListQuery {
|
||||
public class TagFilterQuery extends TagListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "ENTITY_LIST_FILTER";
|
||||
public static final String QUERY_MODE = "TAG_LIST_FILTER";
|
||||
|
||||
public EntityFilterQuery() {
|
||||
public TagFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
@@ -1,16 +1,17 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
package com.tencent.supersonic.chat.query.rule.tag;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class EntityIdQuery extends EntityListQuery {
|
||||
public class TagIdQuery extends TagListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "ENTITY_ID";
|
||||
public static final String QUERY_MODE = "TAG_ID";
|
||||
|
||||
public EntityIdQuery() {
|
||||
public TagIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
package com.tencent.supersonic.chat.query.rule.tag;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
@@ -17,7 +17,7 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
public abstract class EntityListQuery extends EntitySemanticQuery {
|
||||
public abstract class TagListQuery extends TagSemanticQuery {
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
package com.tencent.supersonic.chat.query.rule.tag;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
@@ -24,11 +24,11 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class EntitySemanticQuery extends RuleSemanticQuery {
|
||||
public abstract class TagSemanticQuery extends RuleSemanticQuery {
|
||||
|
||||
private static final Long ENTITY_MAX_RESULTS = 500L;
|
||||
private static final Long TAG_MAX_RESULTS = 500L;
|
||||
|
||||
public EntitySemanticQuery() {
|
||||
public TagSemanticQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
@@ -81,8 +81,8 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
|
||||
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(modelId, queryContext, chatContext);
|
||||
|
||||
parseInfo.setQueryType(QueryType.ENTITY);
|
||||
parseInfo.setLimit(ENTITY_MAX_RESULTS);
|
||||
parseInfo.setQueryType(QueryType.TAG);
|
||||
parseInfo.setLimit(TAG_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
ConfigService configService = ContextUtils.getBean(ConfigService.class);
|
||||
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getModelId());
|
||||
@@ -234,7 +234,7 @@ public class SemanticService {
|
||||
modelInfo.setEntityId(entities.get(0));
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setModel(modelSchema.getModel());
|
||||
semanticParseInfo.setQueryType(QueryType.ENTITY);
|
||||
semanticParseInfo.setQueryType(QueryType.TAG);
|
||||
semanticParseInfo.setMetrics(getMetrics(modelInfo));
|
||||
semanticParseInfo.setDimensions(getDimensions(modelInfo));
|
||||
DateConf dateInfo = new DateConf();
|
||||
|
||||
@@ -86,7 +86,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setQueryType(QueryType.ENTITY);
|
||||
parseInfo.setQueryType(QueryType.TAG);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(
|
||||
|
||||
@@ -56,7 +56,7 @@ class QueryReqBuilderTest {
|
||||
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
|
||||
|
||||
queryStructReq.setQueryType(QueryType.ENTITY);
|
||||
queryStructReq.setQueryType(QueryType.TAG);
|
||||
queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
Assert.assertEquals(
|
||||
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
|
||||
@@ -9,15 +9,15 @@ public enum QueryType {
|
||||
*/
|
||||
METRIC,
|
||||
/**
|
||||
* queries with entity unique key included in the select statement
|
||||
* queries with only tags included in the select statement
|
||||
*/
|
||||
ENTITY,
|
||||
TAG,
|
||||
/**
|
||||
* the other queries
|
||||
*/
|
||||
OTHER;
|
||||
|
||||
public boolean isNativeAggQuery() {
|
||||
return ENTITY.equals(this);
|
||||
return TAG.equals(this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.query.rule.entity.EntityFilterQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.tag.TagFilterQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
@@ -17,17 +17,17 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
|
||||
public class EntityQueryTest extends BaseQueryTest {
|
||||
public class TagQueryTest extends BaseQueryTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_entity_query() throws Exception {
|
||||
public void queryTest_metric_tag_query() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量");
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricEntityQuery.QUERY_MODE);
|
||||
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
|
||||
@@ -43,14 +43,14 @@ public class EntityQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_entity_list_filter() throws Exception {
|
||||
public void queryTest_tag_list_filter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人");
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(EntityFilterQuery.QUERY_MODE);
|
||||
expectedResult.setQueryMode(TagFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
List<String> list = new ArrayList<>();
|
||||
@@ -72,7 +72,7 @@ public class EntityQueryTest extends BaseQueryTest {
|
||||
expectedParseInfo.getDimensions().add(dim4);
|
||||
|
||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, startDay, startDay));
|
||||
expectedParseInfo.setQueryType(QueryType.ENTITY);
|
||||
expectedParseInfo.setQueryType(QueryType.TAG);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.integration.BaseQueryTest;
|
||||
@@ -29,7 +29,7 @@ public class MapperTest extends BaseQueryTest {
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricEntityQuery.QUERY_MODE);
|
||||
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
|
||||
|
||||
@@ -211,7 +211,7 @@ public class QueryReqConverter {
|
||||
|
||||
private QueryType getQueryType(AggOption aggOption) {
|
||||
boolean isAgg = AggOption.isAgg(aggOption);
|
||||
QueryType queryType = QueryType.ENTITY;
|
||||
QueryType queryType = QueryType.TAG;
|
||||
if (isAgg) {
|
||||
queryType = QueryType.METRIC;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user