(improvement) improve evalution accuracy (#727)

This commit is contained in:
mainmain
2024-02-19 17:31:38 +08:00
committed by GitHub
parent fdb69547e6
commit 699a33b1c1
13 changed files with 112 additions and 53 deletions

View File

@@ -77,7 +77,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
//needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
// If there is no aggregate function in the S2SQL statement and
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.

View File

@@ -5,13 +5,21 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -23,11 +31,35 @@ public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Boolean addGroupBy = addGroupBy(queryContext, semanticParseInfo);
log.info("addGroupBy:{}", addGroupBy);
if (!addGroupBy) {
return;
}
addGroupByFields(queryContext, semanticParseInfo);
}
private Boolean addGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long viewId = semanticParseInfo.getViewId();
ViewService viewService = ContextUtils.getBean(ViewService.class);
ModelService modelService = ContextUtils.getBean(ModelService.class);
ViewResp viewResp = viewService.getView(viewId);
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId()
).collect(Collectors.toList());
MetaFilter metaFilter = new MetaFilter(modelIds);
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
for (ModelResp modelResp : modelRespList) {
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
for (Dim dim : dimList) {
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
return false;
}
}
}
return true;
}
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long viewId = semanticParseInfo.getViewId();
//add dimension group by

View File

@@ -27,7 +27,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
addHaving(queryContext, semanticParseInfo);
//add having expression filed to select
addHavingToSelect(semanticParseInfo);
//addHavingToSelect(semanticParseInfo);
}