[improvement](project) Switch ENTITY to TAG uniformly in the queryType and semanticQuery (#420)

This commit is contained in:
lexluo09
2023-11-24 18:17:48 +08:00
committed by GitHub
parent d79e30cd7a
commit fe2a424718
17 changed files with 61 additions and 61 deletions

View File

@@ -4,16 +4,16 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -42,12 +42,12 @@ public class QueryTypeParser implements SemanticParser {
// 1.init S2SQL
semanticQuery.initS2Sql(user);
// 2.set queryType
QueryType queryType = getQueryType(user, semanticQuery);
QueryType queryType = getQueryType(semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);
}
}
private QueryType getQueryType(User user, SemanticQuery semanticQuery) {
private QueryType getQueryType(SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
@@ -55,18 +55,17 @@ public class QueryTypeParser implements SemanticParser {
}
//1. tag queryType
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
// get primaryEntityBizName
// If every field in the SELECT statement is a tag, the query type is TAG.
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
if (Objects.nonNull(entityInfo) && Objects.nonNull(entityInfo.getModelInfo()) && StringUtils.isNotEmpty(
entityInfo.getModelInfo().getPrimaryEntityName())) {
String primaryEntityName = entityInfo.getModelInfo().getPrimaryEntityName();
//if exist primaryEntityName in S2SQL select.
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
boolean existPrimaryEntityName = selectFields.stream()
.anyMatch(fieldName -> primaryEntityName.equalsIgnoreCase(fieldName));
if (existPrimaryEntityName) {
return QueryType.ENTITY;
ModelSchema modelSchema = semanticService.getModelSchema(parseInfo.getModelId());
if (CollectionUtils.isNotEmpty(selectFields) && Objects.nonNull(modelSchema) && CollectionUtils.isNotEmpty(
modelSchema.getTags())) {
Set<String> tags = modelSchema.getTags().stream().map(schemaElement -> schemaElement.getName())
.collect(Collectors.toSet());
if (tags.containsAll(selectFields)) {
return QueryType.TAG;
}
}
}

View File

@@ -15,7 +15,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap;
@@ -94,7 +94,7 @@ public class ContextInheritParser implements SemanticParser {
return matches.stream().anyMatch(m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricEntityQuery)) {
&& !(ruleQuery instanceof MetricTagQuery)) {
return types.contains(type);
}
return type.equals(matchType);

View File

@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
import com.tencent.supersonic.chat.query.rule.tag.TagSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.ArrayList;
import java.util.List;
@@ -83,7 +83,7 @@ public class QueryManager {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery;
return ruleQueryMap.get(queryMode) instanceof TagSemanticQuery;
}
public static boolean containsPluginQuery(String queryMode) {

View File

@@ -90,7 +90,7 @@ public class MetricInterpretQuery extends LLMSemanticQuery {
protected QueryStructReq convertQueryStruct() {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
fillAggregator(queryStructReq, parseInfo.getMetrics());
queryStructReq.setQueryType(QueryType.ENTITY);
queryStructReq.setQueryType(QueryType.TAG);
return queryStructReq;
}

View File

@@ -23,11 +23,11 @@ import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNum
@Slf4j
@Component
public class MetricEntityQuery extends MetricSemanticQuery {
public class MetricTagQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "METRIC_ENTITY";
public static final String QUERY_MODE = "TAG_ENTITY";
public MetricEntityQuery() {
public MetricTagQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.rule.entity;
package com.tencent.supersonic.chat.query.rule.tag;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
@@ -8,11 +8,11 @@ import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType
import org.springframework.stereotype.Component;
@Component
public class EntityDetailQuery extends EntitySemanticQuery {
public class TagDetailQuery extends TagSemanticQuery {
public static final String QUERY_MODE = "ENTITY_DETAIL";
public static final String QUERY_MODE = "TAG_DETAIL";
public EntityDetailQuery() {
public TagDetailQuery() {
super();
queryMatcher.addOption(DIMENSION, REQUIRED, AT_LEAST, 1)
.addOption(ID, REQUIRED, AT_LEAST, 1);

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.rule.entity;
package com.tencent.supersonic.chat.query.rule.tag;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
@@ -10,11 +10,11 @@ import org.springframework.stereotype.Component;
@Slf4j
@Component
public class EntityFilterQuery extends EntityListQuery {
public class TagFilterQuery extends TagListQuery {
public static final String QUERY_MODE = "ENTITY_LIST_FILTER";
public static final String QUERY_MODE = "TAG_LIST_FILTER";
public EntityFilterQuery() {
public TagFilterQuery() {
super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
}

View File

@@ -1,16 +1,17 @@
package com.tencent.supersonic.chat.query.rule.entity;
package com.tencent.supersonic.chat.query.rule.tag;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import org.springframework.stereotype.Component;
@Component
public class EntityIdQuery extends EntityListQuery {
public class TagIdQuery extends TagListQuery {
public static final String QUERY_MODE = "ENTITY_ID";
public static final String QUERY_MODE = "TAG_ID";
public EntityIdQuery() {
public TagIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.rule.entity;
package com.tencent.supersonic.chat.query.rule.tag;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
@@ -17,7 +17,7 @@ import java.util.Objects;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
public abstract class EntityListQuery extends EntitySemanticQuery {
public abstract class TagListQuery extends TagSemanticQuery {
@Override
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.rule.entity;
package com.tencent.supersonic.chat.query.rule.tag;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
@@ -24,11 +24,11 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public abstract class EntitySemanticQuery extends RuleSemanticQuery {
public abstract class TagSemanticQuery extends RuleSemanticQuery {
private static final Long ENTITY_MAX_RESULTS = 500L;
private static final Long TAG_MAX_RESULTS = 500L;
public EntitySemanticQuery() {
public TagSemanticQuery() {
super();
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
}
@@ -81,8 +81,8 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
parseInfo.setQueryType(QueryType.ENTITY);
parseInfo.setLimit(ENTITY_MAX_RESULTS);
parseInfo.setQueryType(QueryType.TAG);
parseInfo.setLimit(TAG_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getModelId());

View File

@@ -234,7 +234,7 @@ public class SemanticService {
modelInfo.setEntityId(entities.get(0));
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setModel(modelSchema.getModel());
semanticParseInfo.setQueryType(QueryType.ENTITY);
semanticParseInfo.setQueryType(QueryType.TAG);
semanticParseInfo.setMetrics(getMetrics(modelInfo));
semanticParseInfo.setDimensions(getDimensions(modelInfo));
DateConf dateInfo = new DateConf();

View File

@@ -86,7 +86,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
} else {
parseInfo.setQueryType(QueryType.ENTITY);
parseInfo.setQueryType(QueryType.TAG);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(

View File

@@ -56,7 +56,7 @@ class QueryReqBuilderTest {
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
queryStructReq.setQueryType(QueryType.ENTITY);
queryStructReq.setQueryType(QueryType.TAG);
queryS2SQLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "

View File

@@ -9,15 +9,15 @@ public enum QueryType {
*/
METRIC,
/**
* queries with entity unique key included in the select statement
* queries whose select statement contains only tags
*/
ENTITY,
TAG,
/**
* the other queries
*/
OTHER;
public boolean isNativeAggQuery() {
return ENTITY.equals(this);
return TAG.equals(this);
}
}

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.rule.entity.EntityFilterQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.tag.TagFilterQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.QueryType;
@@ -17,17 +17,17 @@ import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
public class EntityQueryTest extends BaseQueryTest {
public class TagQueryTest extends BaseQueryTest {
@Test
public void queryTest_metric_entity_query() throws Exception {
public void queryTest_metric_tag_query() throws Exception {
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量");
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricEntityQuery.QUERY_MODE);
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
@@ -43,14 +43,14 @@ public class EntityQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_entity_list_filter() throws Exception {
public void queryTest_tag_list_filter() throws Exception {
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人");
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(EntityFilterQuery.QUERY_MODE);
expectedResult.setQueryMode(TagFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
List<String> list = new ArrayList<>();
@@ -72,7 +72,7 @@ public class EntityQueryTest extends BaseQueryTest {
expectedParseInfo.getDimensions().add(dim4);
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, startDay, startDay));
expectedParseInfo.setQueryType(QueryType.ENTITY);
expectedParseInfo.setQueryType(QueryType.TAG);
assertQueryResult(expectedResult, actualResult);
}

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.integration.BaseQueryTest;
@@ -29,7 +29,7 @@ public class MapperTest extends BaseQueryTest {
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricEntityQuery.QUERY_MODE);
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);

View File

@@ -211,7 +211,7 @@ public class QueryReqConverter {
private QueryType getQueryType(AggOption aggOption) {
boolean isAgg = AggOption.isAgg(aggOption);
QueryType queryType = QueryType.ENTITY;
QueryType queryType = QueryType.TAG;
if (isAgg) {
queryType = QueryType.METRIC;
}