mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(chat) Optimize the update parserInfo code and resolve compilation exceptions (#346)
This commit is contained in:
@@ -20,7 +20,5 @@ public interface SemanticQuery {
|
|||||||
|
|
||||||
SemanticParseInfo getParseInfo();
|
SemanticParseInfo getParseInfo();
|
||||||
|
|
||||||
void updateParseInfo();
|
|
||||||
|
|
||||||
void setParseInfo(SemanticParseInfo parseInfo);
|
void setParseInfo(SemanticParseInfo parseInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,28 +1,17 @@
|
|||||||
|
|
||||||
package com.tencent.supersonic.chat.query;
|
package com.tencent.supersonic.chat.query;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
|
||||||
import com.tencent.supersonic.common.pojo.Filter;
|
import com.tencent.supersonic.common.pojo.Filter;
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||||
@@ -30,19 +19,13 @@ import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
|||||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@ToString
|
@ToString
|
||||||
@@ -87,167 +70,6 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
|||||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void updateParseInfo() {
|
|
||||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
|
||||||
String logicSql = sqlInfo.getLogicSql();
|
|
||||||
if (StringUtils.isBlank(logicSql)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
|
|
||||||
//set dataInfo
|
|
||||||
try {
|
|
||||||
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
|
|
||||||
DateConf dateInfo = getDateInfo(expressions);
|
|
||||||
parseInfo.setDateInfo(dateInfo);
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("set dateInfo error :", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
//set filter
|
|
||||||
try {
|
|
||||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
|
|
||||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
|
||||||
parseInfo.getDimensionFilters().addAll(result);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("set dimensionFilter error :", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
|
|
||||||
if (Objects.isNull(semanticSchema)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
|
|
||||||
|
|
||||||
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
|
||||||
parseInfo.setMetrics(metrics);
|
|
||||||
|
|
||||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
|
|
||||||
parseInfo.setNativeQuery(false);
|
|
||||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
|
|
||||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
|
||||||
parseInfo.setDimensions(
|
|
||||||
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
|
||||||
} else {
|
|
||||||
parseInfo.setNativeQuery(true);
|
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
|
|
||||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
|
||||||
parseInfo.setDimensions(
|
|
||||||
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
|
||||||
return elements.stream()
|
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
|
||||||
&& allFields.contains(schemaElement.getName())
|
|
||||||
).collect(Collectors.toSet());
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<String> getFieldsExceptDate(List<String> allFields) {
|
|
||||||
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
return allFields.stream()
|
|
||||||
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
|
||||||
List<FilterExpression> filterExpressions) {
|
|
||||||
List<QueryFilter> result = Lists.newArrayList();
|
|
||||||
for (FilterExpression expression : filterExpressions) {
|
|
||||||
QueryFilter dimensionFilter = new QueryFilter();
|
|
||||||
dimensionFilter.setValue(expression.getFieldValue());
|
|
||||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
|
||||||
if (Objects.isNull(schemaElement)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
dimensionFilter.setName(schemaElement.getName());
|
|
||||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
|
||||||
dimensionFilter.setElementID(schemaElement.getId());
|
|
||||||
|
|
||||||
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
|
||||||
dimensionFilter.setOperator(operatorEnum);
|
|
||||||
dimensionFilter.setFunction(expression.getFunction());
|
|
||||||
result.add(dimensionFilter);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
|
||||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
|
||||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
|
|
||||||
return new DateConf();
|
|
||||||
}
|
|
||||||
DateConf dateInfo = new DateConf();
|
|
||||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
|
||||||
FilterExpression firstExpression = dateExpressions.get(0);
|
|
||||||
|
|
||||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
|
||||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
|
||||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
|
||||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
|
||||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
|
||||||
return dateInfo;
|
|
||||||
}
|
|
||||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
|
||||||
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
|
||||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
|
||||||
if (hasSecondDate(dateExpressions)) {
|
|
||||||
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
|
||||||
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
|
||||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
|
||||||
if (hasSecondDate(dateExpressions)) {
|
|
||||||
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dateInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
|
||||||
FilterOperatorEnum... operatorEnums) {
|
|
||||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
|
||||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
|
||||||
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
|
||||||
|
|
||||||
List<SchemaElement> allElements = Lists.newArrayList();
|
|
||||||
allElements.addAll(dimensions);
|
|
||||||
allElements.addAll(metrics);
|
|
||||||
//support alias
|
|
||||||
return allElements.stream()
|
|
||||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
|
||||||
.flatMap(schemaElement -> {
|
|
||||||
Set<Pair<String, SchemaElement>> result = new HashSet<>();
|
|
||||||
result.add(Pair.of(schemaElement.getName(), schemaElement));
|
|
||||||
List<String> aliasList = schemaElement.getAlias();
|
|
||||||
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
|
|
||||||
for (String alias : aliasList) {
|
|
||||||
result.add(Pair.of(alias, schemaElement));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result.stream();
|
|
||||||
})
|
|
||||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void convertBizNameToName(QueryStructReq queryStructReq) {
|
protected void convertBizNameToName(QueryStructReq queryStructReq) {
|
||||||
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||||
Map<String, String> bizNameToName = schemaService.getSemanticSchema()
|
Map<String, String> bizNameToName = schemaService.getSemanticSchema()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import java.util.OptionalDouble;
|
|||||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
@@ -22,6 +23,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
||||||
|
log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList()));
|
||||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||||
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
|
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
|
||||||
@@ -45,6 +47,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList()));
|
||||||
return selectedQueries;
|
return selectedQueries;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,17 +3,16 @@ package com.tencent.supersonic.chat.service;
|
|||||||
import com.github.pagehelper.PageInfo;
|
import com.github.pagehelper.PageInfo;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface ChatService {
|
public interface ChatService {
|
||||||
@@ -49,10 +48,7 @@ public interface ChatService {
|
|||||||
|
|
||||||
void addQuery(QueryResult queryResult, ChatContext chatCtx);
|
void addQuery(QueryResult queryResult, ChatContext chatCtx);
|
||||||
|
|
||||||
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq,
|
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
|
||||||
ParseResp parseResult,
|
|
||||||
List<SemanticParseInfo> candidateParses,
|
|
||||||
List<SemanticParseInfo> selectedParses);
|
|
||||||
|
|
||||||
void updateChatParse(List<ChatParseDO> chatParseDOS);
|
void updateChatParse(List<ChatParseDO> chatParseDOS);
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.chat.service;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ParseInfoService {
|
||||||
|
|
||||||
|
List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||||
|
List<SemanticParseInfo> candidateParses);
|
||||||
|
|
||||||
|
List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries);
|
||||||
|
|
||||||
|
void updateParseInfo(SemanticParseInfo parseInfo);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -223,10 +223,9 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq,
|
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
|
||||||
ParseResp parseResult,
|
List<SemanticParseInfo> candidateParses = parseResult.getCandidateParses();
|
||||||
List<SemanticParseInfo> candidateParses,
|
List<SemanticParseInfo> selectedParses = parseResult.getSelectedParses();
|
||||||
List<SemanticParseInfo> selectedParses) {
|
|
||||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,234 @@
|
|||||||
|
|
||||||
|
package com.tencent.supersonic.chat.service.impl;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.service.ParseInfoService;
|
||||||
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
|
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class ParserInfoServiceImpl implements ParseInfoService {
|
||||||
|
|
||||||
|
@Value("${candidate.top.size:5}")
|
||||||
|
private int candidateTopSize;
|
||||||
|
|
||||||
|
public List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||||
|
List<SemanticParseInfo> candidateParses) {
|
||||||
|
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
|
||||||
|
return candidateParses;
|
||||||
|
}
|
||||||
|
int selectParseSize = selectedParses.size();
|
||||||
|
Set<Double> selectParseScoreSet = selectedParses.stream()
|
||||||
|
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
|
||||||
|
int candidateParseSize = candidateTopSize - selectParseSize;
|
||||||
|
candidateParses = candidateParses.stream()
|
||||||
|
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
|
||||||
|
Long modelId = semanticParseInfo.getModelId();
|
||||||
|
if (modelId == null || modelId <= 0) {
|
||||||
|
return candidateParses;
|
||||||
|
}
|
||||||
|
return candidateParses.stream()
|
||||||
|
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
|
||||||
|
.limit(candidateParseSize)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries) {
|
||||||
|
return semanticQueries.stream()
|
||||||
|
.map(SemanticQuery::getParseInfo)
|
||||||
|
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||||
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
|
String logicSql = sqlInfo.getLogicSql();
|
||||||
|
if (StringUtils.isBlank(logicSql)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
|
||||||
|
//set dataInfo
|
||||||
|
try {
|
||||||
|
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
|
||||||
|
DateConf dateInfo = getDateInfo(expressions);
|
||||||
|
parseInfo.setDateInfo(dateInfo);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("set dateInfo error :", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
//set filter
|
||||||
|
try {
|
||||||
|
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
|
||||||
|
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||||
|
parseInfo.getDimensionFilters().addAll(result);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("set dimensionFilter error :", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
|
if (Objects.isNull(semanticSchema)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
|
||||||
|
|
||||||
|
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||||
|
parseInfo.setMetrics(metrics);
|
||||||
|
|
||||||
|
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
|
||||||
|
parseInfo.setNativeQuery(false);
|
||||||
|
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
|
||||||
|
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||||
|
parseInfo.setDimensions(
|
||||||
|
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
||||||
|
} else {
|
||||||
|
parseInfo.setNativeQuery(true);
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
|
||||||
|
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||||
|
parseInfo.setDimensions(
|
||||||
|
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||||
|
&& allFields.contains(schemaElement.getName())
|
||||||
|
).collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> getFieldsExceptDate(List<String> allFields) {
|
||||||
|
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
return allFields.stream()
|
||||||
|
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||||
|
List<FilterExpression> filterExpressions) {
|
||||||
|
List<QueryFilter> result = Lists.newArrayList();
|
||||||
|
for (FilterExpression expression : filterExpressions) {
|
||||||
|
QueryFilter dimensionFilter = new QueryFilter();
|
||||||
|
dimensionFilter.setValue(expression.getFieldValue());
|
||||||
|
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||||
|
if (Objects.isNull(schemaElement)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
dimensionFilter.setName(schemaElement.getName());
|
||||||
|
dimensionFilter.setBizName(schemaElement.getBizName());
|
||||||
|
dimensionFilter.setElementID(schemaElement.getId());
|
||||||
|
|
||||||
|
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
||||||
|
dimensionFilter.setOperator(operatorEnum);
|
||||||
|
dimensionFilter.setFunction(expression.getFunction());
|
||||||
|
result.add(dimensionFilter);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
||||||
|
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||||
|
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
|
||||||
|
return new DateConf();
|
||||||
|
}
|
||||||
|
DateConf dateInfo = new DateConf();
|
||||||
|
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||||
|
FilterExpression firstExpression = dateExpressions.get(0);
|
||||||
|
|
||||||
|
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||||
|
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
||||||
|
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||||
|
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||||
|
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||||
|
return dateInfo;
|
||||||
|
}
|
||||||
|
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
||||||
|
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||||
|
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||||
|
if (hasSecondDate(dateExpressions)) {
|
||||||
|
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
||||||
|
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||||
|
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||||
|
if (hasSecondDate(dateExpressions)) {
|
||||||
|
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dateInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
||||||
|
FilterOperatorEnum... operatorEnums) {
|
||||||
|
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
||||||
|
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
||||||
|
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
||||||
|
|
||||||
|
List<SchemaElement> allElements = Lists.newArrayList();
|
||||||
|
allElements.addAll(dimensions);
|
||||||
|
allElements.addAll(metrics);
|
||||||
|
//support alias
|
||||||
|
return allElements.stream()
|
||||||
|
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
||||||
|
.flatMap(schemaElement -> {
|
||||||
|
Set<Pair<String, SchemaElement>> result = new HashSet<>();
|
||||||
|
result.add(Pair.of(schemaElement.getName(), schemaElement));
|
||||||
|
List<String> aliasList = schemaElement.getAlias();
|
||||||
|
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
|
||||||
|
for (String alias : aliasList) {
|
||||||
|
result.add(Pair.of(alias, schemaElement));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.stream();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -31,6 +31,7 @@ import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
|||||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||||
import com.tencent.supersonic.chat.service.ChatService;
|
import com.tencent.supersonic.chat.service.ChatService;
|
||||||
|
import com.tencent.supersonic.chat.service.ParseInfoService;
|
||||||
import com.tencent.supersonic.chat.service.QueryService;
|
import com.tencent.supersonic.chat.service.QueryService;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
import com.tencent.supersonic.chat.service.StatisticsService;
|
import com.tencent.supersonic.chat.service.StatisticsService;
|
||||||
@@ -58,7 +59,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
|||||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -102,6 +102,8 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
private StatisticsService statisticsService;
|
private StatisticsService statisticsService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private SolvedQueryManager solvedQueryManager;
|
private SolvedQueryManager solvedQueryManager;
|
||||||
|
@Autowired
|
||||||
|
private ParseInfoService parseInfoService;
|
||||||
|
|
||||||
@Value("${time.threshold: 100}")
|
@Value("${time.threshold: 100}")
|
||||||
private Integer timeThreshold;
|
private Integer timeThreshold;
|
||||||
@@ -148,7 +150,6 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
semanticCorrectors.stream().forEach(correction -> {
|
semanticCorrectors.stream().forEach(correction -> {
|
||||||
correction.correct(queryReq, semanticQuery.getParseInfo());
|
correction.correct(queryReq, semanticQuery.getParseInfo());
|
||||||
});
|
});
|
||||||
semanticQuery.updateParseInfo();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,15 +157,14 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
ParseResp parseResult;
|
ParseResp parseResult;
|
||||||
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
|
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
|
||||||
if (candidateQueries.size() > 0) {
|
if (candidateQueries.size() > 0) {
|
||||||
log.debug("pick before [{}]", candidateQueries.stream().collect(
|
|
||||||
Collectors.toList()));
|
|
||||||
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
|
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
|
||||||
log.debug("pick after [{}]", selectedQueries.stream().collect(
|
|
||||||
Collectors.toList()));
|
|
||||||
|
|
||||||
List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries);
|
candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo()));
|
||||||
List<SemanticParseInfo> candidateParses = convertParseInfo(candidateQueries);
|
List<SemanticParseInfo> selectedParses = parseInfoService.sortParseInfo(selectedQueries);
|
||||||
candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses);
|
List<SemanticParseInfo> candidateParses = parseInfoService.sortParseInfo(candidateQueries);
|
||||||
|
|
||||||
|
candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses);
|
||||||
|
|
||||||
parseResult = ParseResp.builder()
|
parseResult = ParseResp.builder()
|
||||||
.chatId(queryReq.getChatId())
|
.chatId(queryReq.getChatId())
|
||||||
.queryText(queryReq.getQueryText())
|
.queryText(queryReq.getQueryText())
|
||||||
@@ -172,7 +172,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
.selectedParses(selectedParses)
|
.selectedParses(selectedParses)
|
||||||
.candidateParses(candidateParses)
|
.candidateParses(candidateParses)
|
||||||
.build();
|
.build();
|
||||||
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
|
||||||
} else {
|
} else {
|
||||||
parseResult = ParseResp.builder()
|
parseResult = ParseResp.builder()
|
||||||
.chatId(queryReq.getChatId())
|
.chatId(queryReq.getChatId())
|
||||||
@@ -198,35 +198,6 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
return parseResult;
|
return parseResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SemanticParseInfo> convertParseInfo(List<SemanticQuery> semanticQueries) {
|
|
||||||
return semanticQueries.stream()
|
|
||||||
.map(SemanticQuery::getParseInfo)
|
|
||||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<SemanticParseInfo> getTop5CandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
|
||||||
List<SemanticParseInfo> candidateParses) {
|
|
||||||
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
|
|
||||||
return candidateParses;
|
|
||||||
}
|
|
||||||
int selectParseSize = selectedParses.size();
|
|
||||||
Set<Double> selectParseScoreSet = selectedParses.stream()
|
|
||||||
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
|
|
||||||
int candidateParseSize = 5 - selectParseSize;
|
|
||||||
candidateParses = candidateParses.stream()
|
|
||||||
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
|
|
||||||
Long modelId = semanticParseInfo.getModelId();
|
|
||||||
if (modelId == null || modelId <= 0) {
|
|
||||||
return candidateParses;
|
|
||||||
}
|
|
||||||
return candidateParses.stream()
|
|
||||||
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
|
|
||||||
.limit(candidateParseSize)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@TimeCost
|
@TimeCost
|
||||||
|
|||||||
Reference in New Issue
Block a user