LLM Corrector: optimize group-by handling (#280)

This commit is contained in:
lexluo09
2023-10-23 23:13:20 +08:00
committed by GitHub
parent d8f81aca65
commit 8fde378534
6 changed files with 66 additions and 35 deletions

View File

@@ -17,5 +17,5 @@ public class DimensionValueReq {
private String bizName;
@NotNull
private Object value;
private String value;
}

View File

@@ -24,7 +24,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
addGroupByFields(semanticCorrectInfo);
addAggregate(semanticCorrectInfo);
}
private void addGroupByFields(SemanticCorrectInfo semanticCorrectInfo) {
@@ -52,6 +51,15 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return;
}
// if only date in select not add group by.
if (selectFields.size() == 1 && selectFields.contains(DateUtils.DATE_FIELD)) {
return;
}
if (SqlParserSelectHelper.hasGroupBy(sql)) {
log.info("not add group by ,exist group by in sql:{}", sql);
return;
}
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
Set<String> groupByFields = selectFields.stream()
.filter(field -> dimensions.contains(field))
@@ -63,6 +71,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
})
.collect(Collectors.toSet());
semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields));
addAggregate(semanticCorrectInfo);
}
private void addAggregate(SemanticCorrectInfo semanticCorrectInfo) {

View File

@@ -95,6 +95,9 @@ public class HanlpDictMapper implements SchemaMapper {
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
if (Objects.isNull(modelSchema)) {
return;
}
Long elementID = NatureHelper.getElementID(nature);
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);

View File

@@ -46,10 +46,10 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@@ -395,9 +395,21 @@ public class LLMS2QLParser implements SemanticParser {
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
//support alias
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
.flatMap(schemaElement -> {
Set<Pair<String, SchemaElement>> result = new HashSet<>();
result.add(Pair.of(schemaElement.getName(), schemaElement));
List<String> aliasList = schemaElement.getAlias();
if (!CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, schemaElement));
}
}
return result.stream();
})
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
}

View File

@@ -50,8 +50,8 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
@@ -594,35 +594,11 @@ public class QueryServiceImpl implements QueryService {
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
if (StringUtils.isBlank(dimensionValueReq.getValue().toString())) {
String nature =
dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT + dimensionValueReq.getElementID();
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
if (CollectionUtils.isEmpty(terms)) {
return null;
}
return terms.stream().map(term -> term.getWord()).collect(Collectors.toSet());
}
return queryHanlpDimensionValue(dimensionValueReq, user);
}
public Object queryHanlpDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
QueryResultWithSchemaResp queryResultWithSchemaResp = new QueryResultWithSchemaResp();
Set<Long> detectModelIds = new HashSet<>();
detectModelIds.add(dimensionValueReq.getModelId());
List<MapResult> mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue().toString(),
2000, dimensionValueReq.getAgentId(), detectModelIds);
HanlpHelper.transLetterOriginal(mapResultList);
mapResultList = mapResultList.stream().filter(o -> {
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
}).collect(Collectors.toList());
log.info("mapResultList:{}", mapResultList);
List<String> dimensionValues = getDimensionValues(dimensionValueReq, detectModelIds);
List<QueryColumn> columns = new ArrayList<>();
QueryColumn queryColumn = new QueryColumn();
queryColumn.setNameEn(dimensionValueReq.getBizName());
@@ -631,9 +607,9 @@ public class QueryServiceImpl implements QueryService {
queryColumn.setType("CHAR");
columns.add(queryColumn);
List<Map<String, Object>> resultList = new ArrayList<>();
mapResultList.stream().forEach(o -> {
dimensionValues.stream().forEach(o -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), o.getName());
map.put(dimensionValueReq.getBizName(), o);
resultList.add(map);
});
queryResultWithSchemaResp.setColumns(columns);
@@ -641,5 +617,34 @@ public class QueryServiceImpl implements QueryService {
return queryResultWithSchemaResp;
}
/**
 * Resolves candidate dimension values for the given request.
 *
 * <p>When the request carries no search value, falls back to the full value list
 * cached in {@code MultiCustomDictionary.NATURE_TO_VALUES} for this model/element
 * nature key. Otherwise performs a prefix search and keeps only results whose
 * natures resolve to the requested element id.
 *
 * @param dimensionValueReq request holding modelId, elementID, agentId and the (possibly blank) search value
 * @param detectModelIds    model ids to restrict the prefix search to
 * @return matching dimension value names; empty list when nothing is cached for the nature key
 */
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> detectModelIds) {
// If the value is blank, look up all known values from NATURE_TO_VALUES instead of searching.
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
// Nature key format: <NATURE_SPILT><modelId><NATURE_SPILT><elementID>.
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT
+ dimensionValueReq.getElementID();
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
if (CollectionUtils.isEmpty(terms)) {
return new ArrayList<>();
}
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
}
// Non-blank value: prefix-search the dictionary (capped at 2000 hits) within the given models.
List<MapResult> mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(),
2000, dimensionValueReq.getAgentId(), detectModelIds);
// Restore original letter casing on the search hits before filtering.
HanlpHelper.transLetterOriginal(mapResultList);
return mapResultList.stream()
.filter(o -> {
// Keep a hit only if at least one of its natures maps to the requested element id.
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
})
.map(mapResult -> mapResult.getName())
.collect(Collectors.toList());
}
}