(improvement)(chat) set native value in parser and execute, optimized the executeDirectQuery code (#196)

This commit is contained in:
lexluo09
2023-10-12 11:38:48 +08:00
committed by GitHub
parent e6f2ce2598
commit b753eda9b9
5 changed files with 165 additions and 136 deletions

View File

@@ -12,6 +12,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
public class EntityInfoExecuteResponder implements ExecuteResponder {
@@ -25,12 +26,9 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, user);
queryResult.setEntityInfo(entityInfo);
if (Objects.isNull(entityInfo) || Objects.isNull(entityInfo.getModelInfo())
|| Objects.isNull(entityInfo.getModelInfo().getPrimaryEntityName())) {
return;
}
String primaryEntityBizName = entityInfo.getModelInfo().getPrimaryEntityBizName();
if (CollectionUtils.isEmpty(queryResult.getQueryColumns())) {
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
if (StringUtils.isEmpty(primaryEntityBizName)
|| CollectionUtils.isEmpty(queryResult.getQueryColumns())) {
return;
}
boolean existPrimaryEntityName = queryResult.getQueryColumns().stream()

View File

@@ -8,8 +8,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
public class EntityInfoParseResponder implements ParseResponder {
@@ -21,13 +22,21 @@ public class EntityInfoParseResponder implements ParseResponder {
}
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.isEntityQuery(queryMode)) {
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class)
.getEntityInfo(parseInfo, queryReq.getUser());
//1. set entity info
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
if (QueryManager.isEntityQuery(parseInfo.getQueryMode())) {
parseInfo.setEntityInfo(entityInfo);
}
//2. set native value
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
if (StringUtils.isNotEmpty(primaryEntityBizName)) {
//if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true
boolean existPrimaryEntityBizName = parseInfo.getDimensions().stream()
.anyMatch(schemaElement -> primaryEntityBizName.equalsIgnoreCase(schemaElement.getBizName()));
parseInfo.setNativeQuery(existPrimaryEntityBizName);
}
});
}
}

View File

@@ -189,6 +189,13 @@ public class SemanticService {
return entityInfo;
}
/**
 * Looks up the primary entity's biz name on the given entity info.
 *
 * @param entityInfo entity info to inspect; may be {@code null}
 * @return {@code entityInfo.getModelInfo().getPrimaryEntityBizName()}, or {@code null} when
 *         {@code entityInfo} or its model info is {@code null}
 */
public String getPrimaryEntityBizName(EntityInfo entityInfo) {
    boolean hasModelInfo = Objects.nonNull(entityInfo) && Objects.nonNull(entityInfo.getModelInfo());
    return hasModelInfo ? entityInfo.getModelInfo().getPrimaryEntityBizName() : null;
}
public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) {
if (StringUtils.isEmpty(entity)) {
return;

View File

@@ -1,66 +1,63 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashMap;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.Comparator;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.query.utils.QueryStructUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
@@ -83,8 +80,6 @@ public class QueryServiceImpl implements QueryService {
private StatisticsService statisticsService;
@Autowired
private SolvedQueryManager solvedQueryManager;
@Autowired
private SearchService searchService;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
@@ -205,8 +200,8 @@ public class QueryServiceImpl implements QueryService {
}
public void saveInfo(List<StatisticsDO> timeCostDOList,
String queryText, Long queryId,
String userName, Long chatId) {
String queryText, Long queryId,
String userName, Long chatId) {
List<StatisticsDO> list = timeCostDOList.stream()
.filter(o -> o.getCost() > timeThreshold).collect(Collectors.toList());
list.forEach(o -> {
@@ -273,110 +268,35 @@ public class QueryServiceImpl implements QueryService {
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException {
ChatParseDO chatParseDO = chatService.getParseInfo(queryData.getQueryId(),
queryData.getUser().getName(), queryData.getParseId());
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticParseInfo parseInfo = getSemanticParseInfo(queryData, chatParseDO);
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (!parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) {
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
}
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
parseInfo.setMetrics(queryData.getMetrics());
}
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
}
if (parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)
&& (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())
|| CollectionUtils.isNotEmpty(queryData.getMetricFilters()))) {
if (DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
log.info("correctorSql before replacing:{}", correctorSql);
List<FilterExpression> filterExpressionList = SqlParserSelectHelper.getFilterExpression(correctorSql);
for (QueryFilter dslQueryFilter : queryData.getDimensionFilters()) {
Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null
&& filterExpression.getFieldName().equals(dslQueryFilter.getName())
&& dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
parseInfo.getDimensionFilters().stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
}
});
break;
}
}
filedNameToValueMap.put(dslQueryFilter.getName(), map);
}
for (QueryFilter dslQueryFilter : queryData.getMetricFilters()) {
Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null
&& filterExpression.getFieldName().equals(dslQueryFilter.getName())
&& dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
parseInfo.getMetricFilters().stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
}
});
break;
}
}
filedNameToValueMap.put(dslQueryFilter.getName(), map);
}
String dateField = "sys_imp_date";
if (Objects.nonNull(queryData.getDateInfo())) {
Map<String, String> map = new HashMap<>();
List<String> dateFields = Lists.newArrayList("dayno", "sys_imp_date", "sys_imp_week", "sys_imp_month");
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null
&& dateFields.contains(filterExpression.getFieldName())) {
dateField = filterExpression.getFieldName();
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
break;
}
}
} else {
for (FilterExpression filterExpression : filterExpressionList) {
if (dateFields.contains(filterExpression.getFieldName())) {
dateField = filterExpression.getFieldName();
if (filterExpression.getOperator().equals(">=")
|| filterExpression.getOperator().equals(">")) {
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
}
if (filterExpression.getOperator().equals("<=")
|| filterExpression.getOperator().equals("<")) {
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getEndDate());
}
}
}
}
filedNameToValueMap.put(dateField, map);
parseInfo.setDateInfo(queryData.getDateInfo());
}
updateFilters(filedNameToValueMap, filterExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters());
updateFilters(filedNameToValueMap, filterExpressionList, queryData.getMetricFilters(),
parseInfo.getMetricFilters());
updateDateInfo(queryData, parseInfo, filedNameToValueMap, filterExpressionList);
log.info("filedNameToValueMap:{}", filedNameToValueMap);
correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap);
log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql);
dslParseResult.setLlmResp(llmResp);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(user);
if (!Objects.isNull(explain)) {
parseInfo.getSqlInfo().setQuerySql(explain.getSql());
@@ -388,6 +308,91 @@ public class QueryServiceImpl implements QueryService {
return queryResult;
}
/**
 * Records the date-value replacements implied by the request's date info and copies that
 * date info onto the parse info. Does nothing when the request carries no date info.
 *
 * <p>When the requested start and end dates are equal, only the first filter expression on
 * an internal time column is considered and its current value is mapped to the start date.
 * Otherwise every filter expression on an internal time column is visited: expressions using
 * a greater-than(-or-equals) operator map their value to the start date, and expressions
 * using a minor-than(-or-equals) operator map their value to the end date.
 *
 * @param queryData            request supplying the new date info
 * @param parseInfo            parse info that receives the request's date info
 * @param filedNameToValueMap  output map; receives one entry keyed by the matched time column
 *                             (or the {@code TimeDimensionEnum.DAY} name when none matched)
 * @param filterExpressionList filter expressions to scan for time columns
 */
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
        Map<String, Map<String, String>> filedNameToValueMap, List<FilterExpression> filterExpressionList) {
    if (Objects.isNull(queryData.getDateInfo())) {
        return;
    }
    List<String> timeColumns = new ArrayList<>(QueryStructUtils.internalTimeCols);
    Map<String, String> replacements = new HashMap<>();
    // Fallback key used when no filter expression references a time column.
    String targetField = TimeDimensionEnum.DAY.getName();
    String startDate = queryData.getDateInfo().getStartDate();
    String endDate = queryData.getDateInfo().getEndDate();

    if (startDate.equals(endDate)) {
        // Single-day request: rewrite only the first time-column expression.
        for (FilterExpression expression : filterExpressionList) {
            String fieldName = expression.getFieldName();
            if (fieldName == null || !timeColumns.contains(fieldName)) {
                continue;
            }
            targetField = fieldName;
            replacements.put(expression.getFieldValue().toString(), startDate);
            break;
        }
    } else {
        // Date range: lower-bound operators take the start date, upper-bound ones the end date.
        for (FilterExpression expression : filterExpressionList) {
            if (!timeColumns.contains(expression.getFieldName())) {
                continue;
            }
            targetField = expression.getFieldName();
            boolean lowerBound =
                    FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(expression.getOperator())
                            || FilterOperatorEnum.GREATER_THAN.getValue().equals(expression.getOperator());
            if (lowerBound) {
                replacements.put(expression.getFieldValue().toString(), startDate);
            }
            boolean upperBound =
                    FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(expression.getOperator())
                            || FilterOperatorEnum.MINOR_THAN.getValue().equals(expression.getOperator());
            if (upperBound) {
                replacements.put(expression.getFieldValue().toString(), endDate);
            }
        }
    }
    filedNameToValueMap.put(targetField, replacements);
    parseInfo.setDateInfo(queryData.getDateInfo());
}
/**
 * For each requested filter, finds the first filter expression with the same field name and
 * operator, records a replacement from the expression's current value to the requested value,
 * and writes the requested value onto every context filter that has the same name.
 *
 * <p>An entry is added to {@code filedNameToValueMap} for every requested filter, even when
 * no expression matched (the entry's map is then empty). Returns immediately when
 * {@code metricFilters} is empty.
 *
 * @param filedNameToValueMap  output map keyed by filter name
 * @param filterExpressionList filter expressions to match against
 * @param metricFilters        filters carrying the requested values
 * @param contextMetricFilters filters whose values are updated in place
 */
private void updateFilters(Map<String, Map<String, String>> filedNameToValueMap,
        List<FilterExpression> filterExpressionList, Set<QueryFilter> metricFilters,
        Set<QueryFilter> contextMetricFilters) {
    if (CollectionUtils.isEmpty(metricFilters)) {
        return;
    }
    for (QueryFilter requested : metricFilters) {
        String filterName = requested.getName();
        Map<String, String> replacements = new HashMap<>();
        for (FilterExpression expression : filterExpressionList) {
            if (expression.getFieldName() == null
                    || !expression.getFieldName().equals(filterName)
                    || !requested.getOperator().getValue().equals(expression.getOperator())) {
                continue;
            }
            replacements.put(expression.getFieldValue().toString(), requested.getValue().toString());
            for (QueryFilter contextFilter : contextMetricFilters) {
                if (contextFilter.getName().equals(filterName)) {
                    contextFilter.setValue(requested.getValue());
                }
            }
            // Only the first matching expression is used for this filter.
            break;
        }
        filedNameToValueMap.put(filterName, replacements);
    }
}
/**
 * Deserializes the stored parse info and, unless it is a DSL query, overrides its
 * dimensions, metrics, dimension filters and date info with the values supplied in the
 * request.
 *
 * @param queryData   request whose non-empty / non-null fields replace the stored ones
 * @param chatParseDO persisted parse record holding the serialized parse info
 * @return the deserialized parse info, with request overrides applied when the query mode
 *         is not {@code DslQuery.QUERY_MODE}
 */
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData, ChatParseDO chatParseDO) {
    SemanticParseInfo restored = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
    boolean dslQuery = DslQuery.QUERY_MODE.equals(restored.getQueryMode());
    if (!dslQuery) {
        // Only values actually present in the request replace the stored ones.
        if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
            restored.setDimensions(queryData.getDimensions());
        }
        if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
            restored.setMetrics(queryData.getMetrics());
        }
        if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
            restored.setDimensionFilters(queryData.getDimensionFilters());
        }
        if (queryData.getDateInfo() != null) {
            restored.setDateInfo(queryData.getDateInfo());
        }
    }
    return restored;
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
QueryStructReq queryStructReq = new QueryStructReq();

View File

@@ -44,8 +44,16 @@ import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class QueryStructUtils {
public static Set<String> internalCols = new HashSet<>(
Arrays.asList("dayno", "plat_sys_var", "sys_imp_date", "sys_imp_week", "sys_imp_month"));
public static Set<String> internalTimeCols = new HashSet<>(
Arrays.asList("dayno", "sys_imp_date", "sys_imp_week", "sys_imp_month"));
public static Set<String> internalCols;
static {
internalCols = new HashSet<>(Arrays.asList("plat_sys_var"));
internalCols.addAll(internalTimeCols);
}
private final DateUtils dateUtils;
private final SqlFilterUtils sqlFilterUtils;
private final Catalog catalog;
@@ -160,11 +168,13 @@ public class QueryStructUtils {
sqlFilterUtils.getFiltersCol(queryStructCmd.getOriginalFilter()).stream().forEach(col -> resNameEnSet.add(col));
return resNameEnSet;
}
public Set<String> getResName(QueryDslReq queryDslReq) {
Set<String> resNameSet = SqlParserSelectHelper.getAllFields(queryDslReq.getSql())
.stream().collect(Collectors.toSet());
.stream().collect(Collectors.toSet());
return resNameSet;
}
public Set<String> getResNameEnExceptInternalCol(QueryStructReq queryStructCmd) {
Set<String> resNameEnSet = getResNameEn(queryStructCmd);
return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet());