(improvement)(chat) Rules and large model queries should be queried using s2sql (#334)

This commit is contained in:
lexluo09
2023-11-07 16:23:31 +08:00
committed by GitHub
parent aa6c658a9a
commit 0365886270
37 changed files with 340 additions and 153 deletions

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -47,7 +47,14 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return elements.stream();
})
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
return result;
}
@@ -61,7 +68,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
}
needAddFields.removeAll(selectFields);
needAddFields.remove(DateUtils.DATE_FIELD);
needAddFields.remove(TimeDimensionEnum.DAY.getChName());
String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields));
semanticCorrectInfo.setSql(replaceFields);
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
public interface CorrectorService {
SemanticCorrectInfo correctorSql(QueryFilters queryFilters, SemanticParseInfo parseInfo, String sql);
void addS2QLAndLoginSql(QueryStructReq queryStructReq, SemanticParseInfo parseInfo);
}

View File

@@ -0,0 +1,95 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class CorrectorServiceImpl implements CorrectorService {
@Autowired
private SchemaService schemaService;
public SemanticCorrectInfo correctorSql(QueryFilters queryFilters, SemanticParseInfo parseInfo, String sql) {
SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
.queryFilters(queryFilters).sql(sql)
.parseInfo(parseInfo).build();
List<SemanticCorrector> corrections = ComponentFactory.getSqlCorrections();
corrections.forEach(correction -> {
try {
correction.correct(correctInfo);
log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql());
} catch (Exception e) {
log.error(String.format("correct error,correctInfo:%s", correctInfo), e);
}
});
return correctInfo;
}
public void addS2QLAndLoginSql(QueryStructReq queryStructReq, SemanticParseInfo parseInfo) {
convertBizNameToName(queryStructReq, parseInfo);
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
parseInfo.getSqlInfo().setS2QL(queryS2QLReq.getSql());
queryStructReq.setS2QL(queryS2QLReq.getSql());
SemanticCorrectInfo semanticCorrectInfo = correctorSql(new QueryFilters(), parseInfo,
queryS2QLReq.getSql());
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
queryStructReq.setLogicSql(semanticCorrectInfo.getSql());
}
private void convertBizNameToName(QueryStructReq queryStructReq, SemanticParseInfo parseInfo) {
Map<String, String> bizNameToName = schemaService.getSemanticSchema()
.getBizNameToName(queryStructReq.getModelId());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
queryStructReq.setModelName(parseInfo.getModelName());
}
}

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -44,7 +44,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return elements.stream();
}
).collect(Collectors.toSet());
dimensions.add(DateUtils.DATE_FIELD);
dimensions.add(TimeDimensionEnum.DAY.getChName());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
@@ -52,7 +52,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return;
}
// if only date in select not add group by.
if (selectFields.size() == 1 && selectFields.contains(DateUtils.DATE_FIELD)) {
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
return;
}
if (SqlParserSelectHelper.hasGroupBy(sql)) {

View File

@@ -7,8 +7,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.s2ql.S2QLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
@@ -71,11 +71,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) {
String currentDate = S2QLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
if (StringUtils.isNotBlank(currentDate)) {
sql = SqlParserAddHelper.addParenthesisToWhere(sql);
sql = SqlParserAddHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
sql = SqlParserAddHelper.addWhere(sql, TimeDimensionEnum.DAY.getChName(), currentDate);
}
}
semanticCorrectInfo.setSql(sql);

View File

@@ -24,7 +24,7 @@ import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

View File

@@ -18,6 +18,7 @@ import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -118,7 +119,7 @@ public class LLMRequestService {
String priorExts = getPriorExts(modelId, fieldNameList);
llmReq.setPriorExts(priorExts);
fieldNameList.add(DateUtils.DATE_FIELD);
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);

View File

@@ -2,27 +2,27 @@ package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.corrector.CorrectorService;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -34,6 +34,7 @@ import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -41,11 +42,15 @@ import org.springframework.util.CollectionUtils;
@Service
public class LLMResponseService {
@Autowired
private CorrectorService correctorService;
public void addParseInfo(QueryContext queryCtx, ParseResult parseResult, String sql, Double weight) {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, parseResult, weight);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql);
QueryFilters queryFilters = queryCtx.getRequest().getQueryFilters();
SemanticCorrectInfo semanticCorrectInfo = correctorService.correctorSql(queryFilters, parseInfo, sql);
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
@@ -64,7 +69,7 @@ public class LLMResponseService {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
@@ -140,7 +145,7 @@ public class LLMResponseService {
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
@@ -182,24 +187,6 @@ public class LLMResponseService {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
.parseInfo(parseInfo).build();
List<SemanticCorrector> corrections = ComponentFactory.getSqlCorrections();
corrections.forEach(correction -> {
try {
correction.correct(correctInfo);
log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql());
} catch (Exception e) {
log.error(String.format("correct error,correctInfo:%s", correctInfo), e);
}
});
return correctInfo;
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, ParseResult parseResult, Double weight) {
if (Objects.isNull(weight)) {

View File

@@ -10,6 +10,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.corrector.CorrectorService;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
@@ -29,6 +30,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -40,6 +42,9 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "METRIC_INTERPRET";
@Autowired
private CorrectorService correctorService;
public MetricInterpretQuery() {
QueryManager.register(this);
}
@@ -58,6 +63,9 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
if (optimizationConfig.isUseS2qlSwitch()) {
correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo);
}
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user);
String text = generateTableText(queryResultWithSchemaResp);

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -42,7 +43,7 @@ public class S2QLQuery extends PluginSemanticQuery {
long startTime = System.currentTimeMillis();
String querySql = parseInfo.getSqlInfo().getLogicSql();
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(querySql);
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(querySql, parseInfo.getModelId());
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
@@ -64,22 +65,22 @@ public class S2QLQuery extends PluginSemanticQuery {
return queryResult;
}
private QueryS2QLReq getQueryS2QLReq(String sql) {
return QueryReqBuilder.buildS2QLReq(sql, parseInfo.getModelId());
}
@Override
public ExplainResp explain(User user) {
public SqlInfo explain(User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
ExplainSqlReq explainSqlReq = null;
try {
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(getQueryS2QLReq(parseInfo.getSqlInfo().getLogicSql()))
.queryReq(queryS2QLReq)
.build();
return semanticInterpreter.explain(explainSqlReq, user);
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
sqlInfo.setQuerySql(explain.getSql());
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
return sqlInfo;
}
}

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.query.plugin;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -20,7 +20,7 @@ public abstract class PluginSemanticQuery implements SemanticQuery {
}
@Override
public ExplainResp explain(User user) {
public SqlInfo explain(User user) {
return null;
}
}

View File

@@ -14,7 +14,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.corrector.CorrectorService;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -27,6 +29,7 @@ import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
import java.util.ArrayList;
@@ -201,7 +204,10 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
if (optimizationConfig.isUseS2qlSwitch()) {
CorrectorService correctorService = ContextUtils.getBean(CorrectorService.class);
correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo);
}
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
if (queryResp != null) {
@@ -221,20 +227,27 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return queryResult;
}
@Override
public ExplainResp explain(User user) {
public SqlInfo explain(User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
ExplainSqlReq explainSqlReq = null;
try {
QueryStructReq queryStructReq = convertQueryStruct();
CorrectorService correctorService = ContextUtils.getBean(CorrectorService.class);
correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo);
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.STRUCT)
.queryReq(isMultiStructQuery()
? convertQueryMultiStruct() : convertQueryStruct())
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(queryS2QLReq)
.build();
return semanticInterpreter.explain(explainSqlReq, user);
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
sqlInfo.setQuerySql(explain.getSql());
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
return sqlInfo;
}
protected boolean isMultiStructQuery() {

View File

@@ -6,26 +6,26 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.ParseTimeCostDO;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
public class ExplainSqlParseResponder implements ParseResponder {
public class SqlInfoParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
List<ChatParseDO> chatParseDOS) {
List<ChatParseDO> chatParseDOS) {
QueryReq queryReq = queryContext.getRequest();
Long startTime = System.currentTimeMillis();
addExplainSql(queryReq, parseResp.getSelectedParses());
addExplainSql(queryReq, parseResp.getCandidateParses());
addSqlInfo(queryReq, parseResp.getSelectedParses());
addSqlInfo(queryReq, parseResp.getCandidateParses());
parseResp.setParseTimeCost(new ParseTimeCostDO());
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
if (!CollectionUtils.isEmpty(chatParseDOS)) {
@@ -49,26 +49,26 @@ public class ExplainSqlParseResponder implements ParseResponder {
}
}
private void addExplainSql(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {
private void addSqlInfo(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) {
return;
}
semanticParseInfos.forEach(parseInfo -> {
addExplainSql(queryReq, parseInfo);
addSqlInfo(queryReq, parseInfo);
});
}
private void addExplainSql(QueryReq queryReq, SemanticParseInfo parseInfo) {
private void addSqlInfo(QueryReq queryReq, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(queryReq.getUser());
if (Objects.isNull(explain)) {
SqlInfo sqlInfo = semanticQuery.explain(queryReq.getUser());
if (Objects.isNull(sqlInfo)) {
return;
}
parseInfo.getSqlInfo().setQuerySql(explain.getSql());
parseInfo.setSqlInfo(sqlInfo);
}
}

View File

@@ -264,8 +264,6 @@ public class SemanticService {
QueryResultWithSchemaResp queryResultWithColumns = null;
try {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(semanticParseInfo);
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
queryResultWithColumns = semanticInterpreter.queryByStruct(queryStructReq, user);
} catch (Exception e) {
log.warn("setMainModel queryByStruct error, e:", e);

View File

@@ -20,6 +20,7 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -39,6 +40,8 @@ import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
@@ -52,9 +55,8 @@ import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -65,8 +67,6 @@ import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
@@ -252,6 +252,7 @@ public class QueryServiceImpl implements QueryService {
queryResult.setQueryTimeCost(System.currentTimeMillis() - executeTime);
return queryResult;
}
// save time cost data
public void saveInfo(List<StatisticsDO> timeCostDOList,
String queryText, Long queryId,
@@ -329,6 +330,7 @@ public class QueryServiceImpl implements QueryService {
ChatContext context = chatService.getOrCreateContext(queryCtx.getChatId());
return context.getParseInfo();
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override
@@ -371,9 +373,9 @@ public class QueryServiceImpl implements QueryService {
log.info("correctorSql after replacing:{}", correctorSql);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(user);
if (!Objects.isNull(explain)) {
parseInfo.getSqlInfo().setQuerySql(explain.getSql());
SqlInfo sqlInfo = semanticQuery.explain(user);
if (!Objects.isNull(sqlInfo)) {
parseInfo.setSqlInfo(sqlInfo);
}
}
semanticQuery.setParseInfo(parseInfo);
@@ -402,7 +404,7 @@ public class QueryServiceImpl implements QueryService {
return;
}
Map<String, String> map = new HashMap<>();
String dateField = DateUtils.DATE_FIELD;
String dateField = TimeDimensionEnum.DAY.getChName();
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
@@ -410,7 +412,7 @@ public class QueryServiceImpl implements QueryService {
// startDate equals to endDate
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FilterExpression filterExpression : filterExpressionList) {
if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) {
if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) {
//sql where condition exists 'equals' operator about date,just replace
if (filterExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
dateField = filterExpression.getFieldName();
@@ -419,9 +421,9 @@ public class QueryServiceImpl implements QueryService {
filedNameToValueMap.put(dateField, map);
} else {
// first remove,then add
removeFieldNames.add(DateUtils.DATE_FIELD);
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
EqualsTo equalsTo = new EqualsTo();
Column column = new Column(DateUtils.DATE_FIELD);
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(queryData.getDateInfo().getStartDate());
equalsTo.setLeftExpression(column);
equalsTo.setRightExpression(stringValue);
@@ -432,7 +434,7 @@ public class QueryServiceImpl implements QueryService {
}
} else {
for (FilterExpression filterExpression : filterExpressionList) {
if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) {
if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) {
dateField = filterExpression.getFieldName();
//just replace
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator())
@@ -448,7 +450,7 @@ public class QueryServiceImpl implements QueryService {
filedNameToValueMap.put(dateField, map);
// first remove,then add
if (FilterOperatorEnum.EQUALS.getValue().equals(filterExpression.getOperator())) {
removeFieldNames.add(DateUtils.DATE_FIELD);
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
MinorThanEquals minorThanEquals = new MinorThanEquals();
@@ -463,7 +465,7 @@ public class QueryServiceImpl implements QueryService {
public <T extends ComparisonOperator> void addTimeFilters(String date,
T comparisonExpression,
List<Expression> addConditions) {
Column column = new Column(DateUtils.DATE_FIELD);
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(stringValue);
@@ -508,6 +510,7 @@ public class QueryServiceImpl implements QueryService {
}
}
}
// add in condition to sql where condition
public void addWhereInFilters(QueryFilter dslQueryFilter,
InExpression inExpression,
@@ -536,6 +539,7 @@ public class QueryServiceImpl implements QueryService {
}
});
}
// add where filter
public <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression,

View File

@@ -50,6 +50,7 @@ public class DictQueryHelper {
@Value("${dimension.white.weight:60000000}")
private Long dimensionWhiteWeight;
public List<String> fetchDimValueSingle(Long modelId, DefaultMetric defaultMetricDesc, Dim4Dict dim4Dict,
User user) {
List<String> data = new ArrayList<>();

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
@@ -25,6 +25,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@@ -141,8 +142,14 @@ public class QueryReqBuilder {
private static List<Aggregator> getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) {
List<Aggregator> aggregators = new ArrayList<>();
if (metric != null) {
String agg = (aggregateType == null || aggregateType.equals(AggregateTypeEnum.NONE)) ? ""
: aggregateType.name();
String agg = "";
if (Objects.isNull(aggregateType) || aggregateType.equals(AggregateTypeEnum.NONE)) {
if (StringUtils.isNotBlank(metric.getDefaultAgg())) {
agg = metric.getDefaultAgg();
}
} else {
agg = aggregateType.name();
}
aggregators.add(new Aggregator(metric.getBizName(), AggOperatorEnum.of(agg)));
}
return aggregators;

View File

@@ -30,6 +30,7 @@ class QueryReqBuilderTest {
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setModelId(1L);
queryStructReq.setNativeQuery(false);
queryStructReq.setModelName("内容库");
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.UNKNOWN);
@@ -51,13 +52,13 @@ class QueryReqBuilderTest {
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, SUM(pv) FROM t_1 WHERE (sys_imp_date IN ('2023-08-01')) "
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2QLReq.getSql());
queryStructReq.setNativeQuery(true);
queryS2QLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, pv FROM t_1 WHERE (sys_imp_date IN ('2023-08-01')) "
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "ORDER BY uv LIMIT 2000",
queryS2QLReq.getSql());