[improvement][headless]Deprecate and remove entity-related abstraction and logic.#1876

This commit is contained in:
jerryjzhang
2024-11-04 00:55:07 +08:00
parent 6a4458a572
commit 1e5bf7909e
49 changed files with 61 additions and 1081 deletions

View File

@@ -1,13 +1,9 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
@@ -18,9 +14,7 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/** Perform SQL corrections on the "Select" section in S2SQL. */
@Slf4j
@@ -42,13 +36,11 @@ public class SelectCorrector extends BaseSemanticCorrector {
&& aggregateFields.size() == selectFields.size()) {
return;
}
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
correctS2SQL = addFieldsToSelect(semanticParseInfo, correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
protected String addFieldsToSelect(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL);
protected String addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
@@ -70,34 +62,4 @@ public class SelectCorrector extends BaseSemanticCorrector {
return addFieldsToSelectSql;
}
private String addTagDefaultFields(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
// If it is in DETAIL mode and select *, add default metrics and dimensions.
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
return correctS2SQL;
}
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
Set<String> needAddDefaultFields = new HashSet<>();
if (Objects.nonNull(dataSetSchema)) {
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
Set<String> metrics = dataSetSchema.getTagDefaultMetrics().stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
needAddDefaultFields.addAll(metrics);
}
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions().stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
needAddDefaultFields.addAll(dimensions);
}
}
// remove * in sql and add default fields.
if (!CollectionUtils.isEmpty(needAddDefaultFields)) {
correctS2SQL =
SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
}
return correctS2SQL;
}
}

View File

@@ -35,9 +35,6 @@ public class NatureHelper {
case DIMENSION:
result = SchemaElementType.DIMENSION;
break;
case ENTITY:
result = SchemaElementType.ENTITY;
break;
case DATASET:
result = SchemaElementType.DATASET;
break;

View File

@@ -1,77 +0,0 @@
package com.tencent.supersonic.headless.chat.mapper;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
/** A mapper capable of converting the VALUE of entity dimension values into ID types. */
@Slf4j
public class EntityMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> schemaElementMatchList =
schemaMapInfo.getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(dataSetId, chatQueryContext);
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch -> SchemaElementType.VALUE
.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
continue;
}
if (!checkExistSameEntitySchemaElements(schemaElementMatch,
schemaElementMatchList)) {
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
entitySchemaElementMath.setElement(entity);
schemaElementMatchList.add(entitySchemaElementMath);
}
schemaElementMatch.getElement().setType(SchemaElementType.ID);
}
}
}
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch -> SchemaElementType.ENTITY
.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
if (schemaElementMatch.getElement().getId()
.equals(valueSchemaElementMatch.getElement().getId())) {
return true;
}
}
return false;
}
private SchemaElement getEntity(Long dataSetId, ChatQueryContext chatQueryContext) {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}
return null;
}
}

View File

@@ -94,9 +94,8 @@ public class MapFilter {
SchemaElement element = schemaElementMatch.getElement();
SchemaElementType type = element.getType();
boolean isEntityOrDatasetOrId = SchemaElementType.ENTITY.equals(type)
|| SchemaElementType.DATASET.equals(type)
|| SchemaElementType.ID.equals(type);
boolean isEntityOrDatasetOrId =
SchemaElementType.DATASET.equals(type) || SchemaElementType.ID.equals(type);
return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
});

View File

@@ -1,29 +1,20 @@
package com.tencent.supersonic.headless.chat.parser;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL or ID. */
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
@Slf4j
public class QueryTypeParser implements SemanticParser {
@@ -52,22 +43,6 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.DETAIL;
}
// 1. entity queryType
Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream()
.map(SchemaElement::getName).collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids)
&& ids.stream().anyMatch(whereFilterByTimeFields::contains)) {
return QueryType.ID;
}
}
}
// 2. AGG queryType
if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) {
return QueryType.AGGREGATE;
@@ -76,20 +51,4 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.DETAIL;
}
private static List<String> filterByTimeFields(List<String> whereFields) {
return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,
SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet =
metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
return selectFields.stream().anyMatch(metricNameSet::contains);
}
return false;
}
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
@@ -124,18 +123,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
SchemaElement element = schemaMatch.getElement();
element.setOrder(1 - schemaMatch.getSimilarity());
switch (element.getType()) {
case ID:
SchemaElement entityElement =
semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(element.getId(),
new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
case VALUE:
SchemaElement dimElement =
semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
@@ -154,13 +141,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
case METRIC:
parseInfo.getMetrics().add(element);
break;
case ENTITY:
parseInfo.setEntity(element);
break;
default:
}
}
addToFilters(id2Values, parseInfo, semanticSchema, SchemaElementType.ENTITY);
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
}
@@ -182,8 +165,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.setEntity(
semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
parseInfo.getDimensionFilters().add(dimensionFilter);
} else {
QueryFilter dimensionFilter = new QueryFilter();
@@ -216,11 +197,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return convertQueryMultiStruct();
}
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
public static List<RuleSemanticQuery> resolve(Long dataSetId,
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
@@ -228,7 +204,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
List<SchemaElementMatch> matches =
semanticQuery.match(candidateElementMatches, chatQueryContext);
if (matches.size() > 0) {
if (!matches.isEmpty()) {
RuleSemanticQuery query =
QueryManager.createRuleQuery(semanticQuery.getQueryMode());
query.getParseInfo().getElementMatches().addAll(matches);
@@ -238,10 +214,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return matchedQueries;
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}
protected QueryMultiStructReq convertQueryMultiStruct() {
return QueryReqBuilder.buildMultiStructReq(parseInfo);
}

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
@Component
public class DetailFilterQuery extends DetailListQuery {
public static final String QUERY_MODE = "DETAIL_LIST_FILTER";
public DetailFilterQuery() {
super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
}

View File

@@ -1,23 +0,0 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Component
public class DetailIdQuery extends DetailListQuery {
public static final String QUERY_MODE = "DETAIL_ID";
public DetailIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
}

View File

@@ -1,65 +0,0 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.DetailTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.collections.CollectionUtils;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
public abstract class DetailListQuery extends DetailSemanticQuery {
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
}
private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext,
SemanticParseInfo parseInfo) {
Long dataSetId = parseInfo.getDataSetId();
if (Objects.isNull(dataSetId) || dataSetId <= 0L) {
return;
}
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
Set<SchemaElement> dimensions = new LinkedHashSet<>();
Set<SchemaElement> metrics = new LinkedHashSet<>();
Set<Order> orders = new LinkedHashSet<>();
DetailTypeDefaultConfig detailTypeDefaultConfig =
dataSetSchema.getTagTypeDefaultConfig();
if (detailTypeDefaultConfig != null
&& detailTypeDefaultConfig.getDefaultDisplayInfo() != null) {
if (CollectionUtils.isNotEmpty(
detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
metrics = detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
.stream().map(id -> {
SchemaElement metric =
dataSetSchema.getElement(SchemaElementType.METRIC, id);
if (metric != null) {
orders.add(
new Order(metric.getBizName(), Constants.DESC_UPPER));
}
return metric;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
if (CollectionUtils.isNotEmpty(
detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
dimensions = detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds()
.stream()
.map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull).collect(Collectors.toSet());
}
}
parseInfo.setDimensions(dimensions);
parseInfo.setMetrics(metrics);
parseInfo.setOrders(orders);
}
}
}

View File

@@ -41,7 +41,7 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
Map<Long, DataSetSchema> dataSetSchemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getDetailTypeTimeDefaultConfig();
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) {

View File

@@ -1,89 +0,0 @@
package com.tencent.supersonic.headless.chat.query.rule.metric;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
@Component
public class MetricIdQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "METRIC_ID";
public MetricIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1).addOption(ENTITY, REQUIRED, AT_LEAST, 1);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public SemanticQueryReq buildSemanticQueryReq() {
if (!isMultiStructQuery()) {
return super.buildSemanticQueryReq();
}
return super.multiStructExecute();
}
protected boolean isMultiStructQuery() {
Set<String> filterBizName = new HashSet<>();
parseInfo.getDimensionFilters().stream().filter(filter -> filter.getElementID() != null)
.forEach(filter -> filterBizName.add(filter.getBizName()));
return FilterType.UNION.equals(parseInfo.getFilterType()) && filterBizName.size() > 1;
}
@Override
protected QueryStructReq convertQueryStruct() {
QueryStructReq queryStructReq = super.convertQueryStruct();
addDimension(queryStructReq, true);
return queryStructReq;
}
@Override
protected QueryMultiStructReq convertQueryMultiStruct() {
QueryMultiStructReq queryMultiStructReq = super.convertQueryMultiStruct();
for (QueryStructReq queryStructReq : queryMultiStructReq.getQueryStructReqs()) {
addDimension(queryStructReq, false);
}
return queryMultiStructReq;
}
private void addDimension(QueryStructReq queryStructReq, boolean onlyOperateInFilter) {
if (!queryStructReq.getDimensionFilters().isEmpty()) {
List<String> dimensions = queryStructReq.getGroups();
log.info("addDimension before [{}]", queryStructReq.getGroups());
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
if (onlyOperateInFilter) {
filters = filters.stream()
.filter(filter -> filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
}
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
queryStructReq.setGroups(dimensions);
log.info("addDimension after [{}]", queryStructReq.getGroups());
}
}
}

View File

@@ -1,108 +0,0 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.DetailTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.springframework.core.env.Environment;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.mockito.Mockito.when;
class SelectCorrectorTest {
Long dataSetId = 2L;
@Test
void testDoCorrect() {
MockedStatic<ContextUtils> mocked = Mockito.mockStatic(ContextUtils.class);
Environment mockEnvironment = Mockito.mock(Environment.class);
mocked.when(() -> ContextUtils.getBean(Environment.class)).thenReturn(mockEnvironment);
when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn("");
BaseSemanticCorrector corrector = new SelectCorrector();
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(dataSet);
semanticParseInfo.setQueryType(QueryType.DETAIL);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {
ChatQueryContext chatQueryContext = new ChatQueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
DetailTypeDefaultConfig detailTypeDefaultConfig = new DetailTypeDefaultConfig();
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
List<Long> dimensionIds = new ArrayList<>();
dimensionIds.add(1L);
dimensionIds.add(2L);
dimensionIds.add(3L);
defaultDisplayInfo.setDimensionIds(dimensionIds);
List<Long> metricIds = new ArrayList<>();
metricIds.add(4L);
defaultDisplayInfo.setMetricIds(metricIds);
detailTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSetId(dataSetId);
element1.setId(1L);
element1.setName("艺人名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setDataSetId(dataSetId);
element2.setId(2L);
element2.setName("性别");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSetId(dataSetId);
element3.setId(3L);
element3.setName("国籍");
dimensions.add(element3);
dataSetSchema.setDimensions(dimensions);
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSetId(dataSetId);
metric1.setId(4L);
metric1.setName("粉丝数");
metrics.add(metric1);
dataSetSchema.setMetrics(metrics);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
chatQueryContext.setSemanticSchema(semanticSchema);
return chatQueryContext;
}
}