mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(chat) Optimize the update parserInfo code and resolve compilation exceptions (#346)
This commit is contained in:
@@ -20,7 +20,5 @@ public interface SemanticQuery {
|
||||
|
||||
SemanticParseInfo getParseInfo();
|
||||
|
||||
void updateParseInfo();
|
||||
|
||||
void setParseInfo(SemanticParseInfo parseInfo);
|
||||
}
|
||||
|
||||
@@ -1,28 +1,17 @@
|
||||
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
@@ -30,19 +19,13 @@ import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@@ -87,167 +70,6 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
public void updateParseInfo() {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String logicSql = sqlInfo.getLogicSql();
|
||||
if (StringUtils.isBlank(logicSql)) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
|
||||
DateConf dateInfo = getDateInfo(expressions);
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("set dateInfo error :", e);
|
||||
}
|
||||
|
||||
//set filter
|
||||
try {
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
|
||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(result);
|
||||
} catch (Exception e) {
|
||||
log.error("set dimensionFilter error :", e);
|
||||
}
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
|
||||
|
||||
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setNativeQuery(true);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||
&& allFields.contains(schemaElement.getName())
|
||||
).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private List<String> getFieldsExceptDate(List<String> allFields) {
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return allFields.stream()
|
||||
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FilterExpression expression : filterExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
dimensionFilter.setValue(expression.getFieldValue());
|
||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
continue;
|
||||
}
|
||||
dimensionFilter.setName(schemaElement.getName());
|
||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
||||
dimensionFilter.setElementID(schemaElement.getId());
|
||||
|
||||
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
||||
dimensionFilter.setOperator(operatorEnum);
|
||||
dimensionFilter.setFunction(expression.getFunction());
|
||||
result.add(dimensionFilter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return new DateConf();
|
||||
}
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
FilterExpression firstExpression = dateExpressions.get(0);
|
||||
|
||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
return dateInfo;
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
||||
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
||||
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
return dateInfo;
|
||||
}
|
||||
|
||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||
}
|
||||
|
||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
||||
|
||||
List<SchemaElement> allElements = Lists.newArrayList();
|
||||
allElements.addAll(dimensions);
|
||||
allElements.addAll(metrics);
|
||||
//support alias
|
||||
return allElements.stream()
|
||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<Pair<String, SchemaElement>> result = new HashSet<>();
|
||||
result.add(Pair.of(schemaElement.getName(), schemaElement));
|
||||
List<String> aliasList = schemaElement.getAlias();
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, schemaElement));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(QueryStructReq queryStructReq) {
|
||||
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
Map<String, String> bizNameToName = schemaService.getSemanticSchema()
|
||||
|
||||
@@ -14,6 +14,7 @@ import java.util.OptionalDouble;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -22,6 +23,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
||||
|
||||
@Override
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
||||
log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
|
||||
@@ -45,6 +47,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
||||
});
|
||||
}
|
||||
}
|
||||
log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList()));
|
||||
return selectedQueries;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,17 +3,16 @@ package com.tencent.supersonic.chat.service;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatService {
|
||||
@@ -49,10 +48,7 @@ public interface ChatService {
|
||||
|
||||
void addQuery(QueryResult queryResult, ChatContext chatCtx);
|
||||
|
||||
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses,
|
||||
List<SemanticParseInfo> selectedParses);
|
||||
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
|
||||
|
||||
void updateChatParse(List<ChatParseDO> chatParseDOS);
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.service;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import java.util.List;
|
||||
|
||||
public interface ParseInfoService {
|
||||
|
||||
List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries);
|
||||
|
||||
void updateParseInfo(SemanticParseInfo parseInfo);
|
||||
|
||||
}
|
||||
@@ -223,10 +223,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses,
|
||||
List<SemanticParseInfo> selectedParses) {
|
||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> candidateParses = parseResult.getCandidateParses();
|
||||
List<SemanticParseInfo> selectedParses = parseResult.getSelectedParses();
|
||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.service.ParseInfoService;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
|
||||
@Value("${candidate.top.size:5}")
|
||||
private int candidateTopSize;
|
||||
|
||||
public List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
|
||||
return candidateParses;
|
||||
}
|
||||
int selectParseSize = selectedParses.size();
|
||||
Set<Double> selectParseScoreSet = selectedParses.stream()
|
||||
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
|
||||
int candidateParseSize = candidateTopSize - selectParseSize;
|
||||
candidateParses = candidateParses.stream()
|
||||
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
|
||||
.collect(Collectors.toList());
|
||||
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
|
||||
Long modelId = semanticParseInfo.getModelId();
|
||||
if (modelId == null || modelId <= 0) {
|
||||
return candidateParses;
|
||||
}
|
||||
return candidateParses.stream()
|
||||
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
|
||||
.limit(candidateParseSize)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries) {
|
||||
return semanticQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo)
|
||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String logicSql = sqlInfo.getLogicSql();
|
||||
if (StringUtils.isBlank(logicSql)) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
|
||||
DateConf dateInfo = getDateInfo(expressions);
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("set dateInfo error :", e);
|
||||
}
|
||||
|
||||
//set filter
|
||||
try {
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
|
||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(result);
|
||||
} catch (Exception e) {
|
||||
log.error("set dimensionFilter error :", e);
|
||||
}
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
|
||||
|
||||
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setNativeQuery(true);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||
&& allFields.contains(schemaElement.getName())
|
||||
).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private List<String> getFieldsExceptDate(List<String> allFields) {
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return allFields.stream()
|
||||
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FilterExpression expression : filterExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
dimensionFilter.setValue(expression.getFieldValue());
|
||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
continue;
|
||||
}
|
||||
dimensionFilter.setName(schemaElement.getName());
|
||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
||||
dimensionFilter.setElementID(schemaElement.getId());
|
||||
|
||||
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
||||
dimensionFilter.setOperator(operatorEnum);
|
||||
dimensionFilter.setFunction(expression.getFunction());
|
||||
result.add(dimensionFilter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return new DateConf();
|
||||
}
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
FilterExpression firstExpression = dateExpressions.get(0);
|
||||
|
||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
return dateInfo;
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
||||
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
||||
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
return dateInfo;
|
||||
}
|
||||
|
||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||
}
|
||||
|
||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
||||
|
||||
List<SchemaElement> allElements = Lists.newArrayList();
|
||||
allElements.addAll(dimensions);
|
||||
allElements.addAll(metrics);
|
||||
//support alias
|
||||
return allElements.stream()
|
||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<Pair<String, SchemaElement>> result = new HashSet<>();
|
||||
result.add(Pair.of(schemaElement.getName(), schemaElement));
|
||||
List<String> aliasList = schemaElement.getAlias();
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, schemaElement));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,7 @@ import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.ParseInfoService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.service.StatisticsService;
|
||||
@@ -58,7 +59,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -102,6 +102,8 @@ public class QueryServiceImpl implements QueryService {
|
||||
private StatisticsService statisticsService;
|
||||
@Autowired
|
||||
private SolvedQueryManager solvedQueryManager;
|
||||
@Autowired
|
||||
private ParseInfoService parseInfoService;
|
||||
|
||||
@Value("${time.threshold: 100}")
|
||||
private Integer timeThreshold;
|
||||
@@ -148,7 +150,6 @@ public class QueryServiceImpl implements QueryService {
|
||||
semanticCorrectors.stream().forEach(correction -> {
|
||||
correction.correct(queryReq, semanticQuery.getParseInfo());
|
||||
});
|
||||
semanticQuery.updateParseInfo();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,15 +157,14 @@ public class QueryServiceImpl implements QueryService {
|
||||
ParseResp parseResult;
|
||||
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
|
||||
if (candidateQueries.size() > 0) {
|
||||
log.debug("pick before [{}]", candidateQueries.stream().collect(
|
||||
Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
|
||||
log.debug("pick after [{}]", selectedQueries.stream().collect(
|
||||
Collectors.toList()));
|
||||
|
||||
List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries);
|
||||
List<SemanticParseInfo> candidateParses = convertParseInfo(candidateQueries);
|
||||
candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses);
|
||||
candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo()));
|
||||
List<SemanticParseInfo> selectedParses = parseInfoService.sortParseInfo(selectedQueries);
|
||||
List<SemanticParseInfo> candidateParses = parseInfoService.sortParseInfo(candidateQueries);
|
||||
|
||||
candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses);
|
||||
|
||||
parseResult = ParseResp.builder()
|
||||
.chatId(queryReq.getChatId())
|
||||
.queryText(queryReq.getQueryText())
|
||||
@@ -172,7 +172,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
.selectedParses(selectedParses)
|
||||
.candidateParses(candidateParses)
|
||||
.build();
|
||||
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
||||
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
|
||||
} else {
|
||||
parseResult = ParseResp.builder()
|
||||
.chatId(queryReq.getChatId())
|
||||
@@ -198,35 +198,6 @@ public class QueryServiceImpl implements QueryService {
|
||||
return parseResult;
|
||||
}
|
||||
|
||||
private List<SemanticParseInfo> convertParseInfo(List<SemanticQuery> semanticQueries) {
|
||||
return semanticQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo)
|
||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<SemanticParseInfo> getTop5CandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
|
||||
return candidateParses;
|
||||
}
|
||||
int selectParseSize = selectedParses.size();
|
||||
Set<Double> selectParseScoreSet = selectedParses.stream()
|
||||
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
|
||||
int candidateParseSize = 5 - selectParseSize;
|
||||
candidateParses = candidateParses.stream()
|
||||
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
|
||||
.collect(Collectors.toList());
|
||||
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
|
||||
Long modelId = semanticParseInfo.getModelId();
|
||||
if (modelId == null || modelId <= 0) {
|
||||
return candidateParses;
|
||||
}
|
||||
return candidateParses.stream()
|
||||
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
|
||||
.limit(candidateParseSize)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@TimeCost
|
||||
|
||||
Reference in New Issue
Block a user