mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
llm Corrector group by optimize (#280)
This commit is contained in:
@@ -17,5 +17,5 @@ public class DimensionValueReq {
|
||||
private String bizName;
|
||||
|
||||
@NotNull
|
||||
private Object value;
|
||||
private String value;
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
addGroupByFields(semanticCorrectInfo);
|
||||
|
||||
addAggregate(semanticCorrectInfo);
|
||||
}
|
||||
|
||||
private void addGroupByFields(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
@@ -52,6 +51,15 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
// If the date field is the only selected field, do not add a GROUP BY clause.
|
||||
if (selectFields.size() == 1 && selectFields.contains(DateUtils.DATE_FIELD)) {
|
||||
return;
|
||||
}
|
||||
if (SqlParserSelectHelper.hasGroupBy(sql)) {
|
||||
log.info("not add group by ,exist group by in sql:{}", sql);
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
@@ -63,6 +71,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields));
|
||||
|
||||
addAggregate(semanticCorrectInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
@@ -95,6 +95,9 @@ public class HanlpDictMapper implements SchemaMapper {
|
||||
|
||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||
if (Objects.isNull(modelSchema)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
|
||||
|
||||
@@ -46,10 +46,10 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
@@ -395,9 +395,21 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
List<SchemaElement> allElements = Lists.newArrayList();
|
||||
allElements.addAll(dimensions);
|
||||
allElements.addAll(metrics);
|
||||
// Support aliases: register each element under its name and under every one of its aliases.
|
||||
return allElements.stream()
|
||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
|
||||
.flatMap(schemaElement -> {
|
||||
Set<Pair<String, SchemaElement>> result = new HashSet<>();
|
||||
result.add(Pair.of(schemaElement.getName(), schemaElement));
|
||||
List<String> aliasList = schemaElement.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, schemaElement));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
@@ -594,35 +594,11 @@ public class QueryServiceImpl implements QueryService {
|
||||
|
||||
@Override
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
||||
if (StringUtils.isBlank(dimensionValueReq.getValue().toString())) {
|
||||
String nature =
|
||||
dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT + dimensionValueReq.getElementID();
|
||||
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return null;
|
||||
}
|
||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toSet());
|
||||
}
|
||||
return queryHanlpDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
public Object queryHanlpDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = new QueryResultWithSchemaResp();
|
||||
Set<Long> detectModelIds = new HashSet<>();
|
||||
detectModelIds.add(dimensionValueReq.getModelId());
|
||||
List<MapResult> mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue().toString(),
|
||||
2000, dimensionValueReq.getAgentId(), detectModelIds);
|
||||
HanlpHelper.transLetterOriginal(mapResultList);
|
||||
mapResultList = mapResultList.stream().filter(o -> {
|
||||
for (String nature : o.getNatures()) {
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
if (dimensionValueReq.getElementID().equals(elementID)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("mapResultList:{}", mapResultList);
|
||||
List<String> dimensionValues = getDimensionValues(dimensionValueReq, detectModelIds);
|
||||
|
||||
List<QueryColumn> columns = new ArrayList<>();
|
||||
QueryColumn queryColumn = new QueryColumn();
|
||||
queryColumn.setNameEn(dimensionValueReq.getBizName());
|
||||
@@ -631,9 +607,9 @@ public class QueryServiceImpl implements QueryService {
|
||||
queryColumn.setType("CHAR");
|
||||
columns.add(queryColumn);
|
||||
List<Map<String, Object>> resultList = new ArrayList<>();
|
||||
mapResultList.stream().forEach(o -> {
|
||||
dimensionValues.stream().forEach(o -> {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
map.put(dimensionValueReq.getBizName(), o.getName());
|
||||
map.put(dimensionValueReq.getBizName(), o);
|
||||
resultList.add(map);
|
||||
});
|
||||
queryResultWithSchemaResp.setColumns(columns);
|
||||
@@ -641,5 +617,34 @@ public class QueryServiceImpl implements QueryService {
|
||||
return queryResultWithSchemaResp;
|
||||
}
|
||||
|
||||
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> detectModelIds) {
|
||||
//if value is null ,then search from NATURE_TO_VALUES
|
||||
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
|
||||
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT
|
||||
+ dimensionValueReq.getElementID();
|
||||
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
||||
}
|
||||
//search from prefixSearch
|
||||
List<MapResult> mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(),
|
||||
2000, dimensionValueReq.getAgentId(), detectModelIds);
|
||||
HanlpHelper.transLetterOriginal(mapResultList);
|
||||
return mapResultList.stream()
|
||||
.filter(o -> {
|
||||
for (String nature : o.getNatures()) {
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
if (dimensionValueReq.getElementID().equals(elementID)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.map(mapResult -> mapResult.getName())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user