(improvement)(chat) Overall code optimization for the corrector (#345)

This commit is contained in:
lexluo09
2023-11-09 16:03:05 +08:00
committed by GitHub
parent 608a4f7a2f
commit 16c3de44e4
25 changed files with 507 additions and 542 deletions

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import net.sf.jsqlparser.JSQLParserException;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
/**
* A semantic corrector checks validity of extracted semantic information and
@@ -9,5 +9,5 @@ import net.sf.jsqlparser.JSQLParserException;
*/
public interface SemanticCorrector {
void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException;
void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import org.apache.calcite.sql.parser.SqlParseException;
/**
@@ -15,9 +14,13 @@ public interface SemanticQuery {
QueryResult execute(User user) throws SqlParseException;
SqlInfo explain(User user);
void initS2Sql(User user);
String explain(User user);
SemanticParseInfo getParseInfo();
void updateParseInfo();
void setParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -42,7 +42,7 @@ public class OptimizationConfig {
@Value("${user.s2ql.switch:false}")
private boolean useS2qlSwitch;
@Value("${embedding.mapper.word.min:3}")
@Value("${embedding.mapper.word.min:4}")
private int embeddingMapperWordMin;
@Value("${embedding.mapper.word.max:5}")
@@ -57,6 +57,6 @@ public class OptimizationConfig {
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.52}")
@Value("${embedding.mapper.distance.threshold:0.58}")
private Double embeddingMapperDistanceThreshold;
}

View File

@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -18,15 +19,27 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getLogicSql())) {
return;
}
work(queryReq, semanticParseInfo);
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
} catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
}
}
public abstract void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -58,10 +71,10 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return result;
}
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(sql));
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String logicSql) {
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(logicSql));
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(logicSql));
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(logicSql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return;
@@ -69,14 +82,14 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
needAddFields.removeAll(selectFields);
needAddFields.remove(TimeDimensionEnum.DAY.getChName());
String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields));
semanticCorrectInfo.setSql(replaceFields);
String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setLogicSql(replaceFields);
}
protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String sql = semanticCorrectInfo.getSql();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
Long modelId = semanticParseInfo.getModel().getModel();
List<SchemaElement> metrics = getMetricElements(modelId);
@@ -91,9 +104,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlParserAddHelper.addAggregateToField(sql, metricToAggregate);
semanticCorrectInfo.setSql(aggregateSql);
String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate);
semanticParseInfo.getSqlInfo().setLogicSql(aggregateSql);
}
protected List<SchemaElement> getMetricElements(Long modelId) {

View File

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
/**
 * Entry point for applying SQL corrections to a parsed chat query.
 */
public interface CorrectorService {
/**
 * Wraps the given filters, parse info and SQL into a {@link SemanticCorrectInfo}
 * and returns it after correction.
 *
 * @param queryFilters filters carried into the correction info
 * @param parseInfo    parse info carried into the correction info
 * @param sql          SQL text to be corrected
 * @return the correction info; its {@code sql} holds the corrected statement
 */
SemanticCorrectInfo correctorSql(QueryFilters queryFilters, SemanticParseInfo parseInfo, String sql);
/**
 * Derives the S2QL and the corrected logic SQL for a struct query and stores
 * them on both {@code queryStructReq} and {@code parseInfo}'s SQL info.
 *
 * @param queryStructReq struct query to convert; updated in place
 * @param parseInfo      parse info whose SQL info receives the S2QL and logic SQL
 */
void addS2QLAndLoginSql(QueryStructReq queryStructReq, SemanticParseInfo parseInfo);
}

View File

@@ -1,95 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class CorrectorServiceImpl implements CorrectorService {
@Autowired
private SchemaService schemaService;
public SemanticCorrectInfo correctorSql(QueryFilters queryFilters, SemanticParseInfo parseInfo, String sql) {
SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
.queryFilters(queryFilters).sql(sql)
.parseInfo(parseInfo).build();
List<SemanticCorrector> corrections = ComponentFactory.getSqlCorrections();
corrections.forEach(correction -> {
try {
correction.correct(correctInfo);
log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql());
} catch (Exception e) {
log.error(String.format("correct error,correctInfo:%s", correctInfo), e);
}
});
return correctInfo;
}
public void addS2QLAndLoginSql(QueryStructReq queryStructReq, SemanticParseInfo parseInfo) {
convertBizNameToName(queryStructReq, parseInfo);
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
parseInfo.getSqlInfo().setS2QL(queryS2QLReq.getSql());
queryStructReq.setS2QL(queryS2QLReq.getSql());
SemanticCorrectInfo semanticCorrectInfo = correctorSql(new QueryFilters(), parseInfo,
queryS2QLReq.getSql());
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
queryStructReq.setLogicSql(semanticCorrectInfo.getSql());
}
private void convertBizNameToName(QueryStructReq queryStructReq, SemanticParseInfo parseInfo) {
Map<String, String> bizNameToName = schemaService.getSemanticSchema()
.getBizNameToName(queryStructReq.getModelId());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
queryStructReq.setModelName(parseInfo.getModelName());
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
@@ -12,17 +13,16 @@ import net.sf.jsqlparser.expression.Expression;
public class GlobalAfterCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
super.correct(semanticCorrectInfo);
String sql = semanticCorrectInfo.getSql();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(sql)) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression);
semanticCorrectInfo.setSql(replaceSql);
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
semanticParseInfo.getSqlInfo().setLogicSql(replaceSql);
}
return;
}

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
@@ -19,35 +21,32 @@ import org.springframework.util.CollectionUtils;
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
super.correct(semanticCorrectInfo);
replaceAlias(semanticParseInfo);
replaceAlias(semanticCorrectInfo);
updateFieldNameByLinkingValue(semanticParseInfo);
updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldValueByLinkingValue(semanticParseInfo);
updateFieldValueByLinkingValue(semanticCorrectInfo);
correctFieldName(semanticCorrectInfo);
correctFieldName(semanticParseInfo);
}
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserReplaceHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getLogicSql());
sqlInfo.setLogicSql(replaceAlias);
}
private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserReplaceHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
semanticCorrectInfo.setSql(sql);
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModelId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getLogicSql(), fieldNameMap);
sqlInfo.setLogicSql(sql);
}
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
if (CollectionUtils.isEmpty(linking)) {
return;
}
@@ -56,13 +55,14 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(),
fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql);
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getLogicSql(), fieldValueToFieldNames);
sqlInfo.setLogicSql(sql);
}
private List<ElementValue> getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) {
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return null;
}
@@ -76,8 +76,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
}
private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
if (CollectionUtils.isEmpty(linking)) {
return;
}
@@ -90,7 +90,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
(existingValue, newValue) -> newValue)
)));
String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false);
semanticCorrectInfo.setSql(sql);
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getLogicSql(), filedNameToValueMap, false);
sqlInfo.setLogicSql(sql);
}
}

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
@@ -18,19 +20,18 @@ import org.springframework.util.CollectionUtils;
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
super.correct(semanticCorrectInfo);
addGroupByFields(semanticCorrectInfo);
addGroupByFields(semanticParseInfo);
}
private void addGroupByFields(SemanticCorrectInfo semanticCorrectInfo) {
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
Long modelId = semanticParseInfo.getModel().getModel();
//add dimension group by
String sql = semanticCorrectInfo.getSql();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String logicSql = sqlInfo.getLogicSql();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
//add alias field name
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
@@ -46,7 +47,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return;
@@ -55,12 +56,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
return;
}
if (SqlParserSelectHelper.hasGroupBy(sql)) {
log.info("not add group by ,exist group by in sql:{}", sql);
if (SqlParserSelectHelper.hasGroupBy(logicSql)) {
log.info("not add group by ,exist group by in logicSql:{}", logicSql);
return;
}
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
Set<String> groupByFields = selectFields.stream()
.filter(field -> dimensions.contains(field))
.filter(field -> {
@@ -70,16 +71,17 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
})
.collect(Collectors.toSet());
semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields));
semanticParseInfo.getSqlInfo().setLogicSql(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
addAggregate(semanticCorrectInfo);
addAggregate(semanticParseInfo);
}
private void addAggregate(SemanticCorrectInfo semanticCorrectInfo) {
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
private void addAggregate(SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getLogicSql());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}
addAggregateToMetric(semanticCorrectInfo);
addAggregateToMetric(semanticParseInfo);
}
}

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -14,13 +15,10 @@ import org.springframework.util.CollectionUtils;
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
Long modelId = semanticParseInfo.getModel().getModel();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -30,8 +28,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(metrics)) {
return;
}
String havingSql = SqlParserAddHelper.addHaving(semanticCorrectInfo.getSql(), metrics);
semanticCorrectInfo.setSql(havingSql);
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getLogicSql(), metrics);
semanticParseInfo.getSqlInfo().setLogicSql(havingSql);
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
@@ -10,17 +11,16 @@ import org.springframework.util.CollectionUtils;
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
String sql = semanticCorrectInfo.getSql();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
if (!CollectionUtils.isEmpty(aggregateFields)
&& !CollectionUtils.isEmpty(selectFields)
&& aggregateFields.size() == selectFields.size()) {
return;
}
addFieldsToSelect(semanticCorrectInfo, sql);
addFieldsToSelect(semanticParseInfo, logicSql);
}
}

View File

@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.llm.s2ql.S2QLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
@@ -31,54 +32,52 @@ import org.springframework.util.CollectionUtils;
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
super.correct(semanticCorrectInfo);
addDateIfNotExist(semanticParseInfo);
addDateIfNotExist(semanticCorrectInfo);
parserDateDiffFunction(semanticParseInfo);
parserDateDiffFunction(semanticCorrectInfo);
addQueryFilter(queryReq, semanticParseInfo);
addQueryFilter(semanticCorrectInfo);
updateFieldValueByTechName(semanticCorrectInfo);
updateFieldValueByTechName(semanticParseInfo);
}
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
log.info("add queryFilter to logicSql :{}", queryFilter);
Expression expression = null;
try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
String sql = SqlParserAddHelper.addWhere(preSql, expression);
semanticCorrectInfo.setSql(sql);
logicSql = SqlParserAddHelper.addWhere(logicSql, expression);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
}
}
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
sql = SqlParserReplaceHelper.replaceFunction(sql);
semanticCorrectInfo.setSql(sql);
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
logicSql = SqlParserReplaceHelper.replaceFunction(logicSql);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
}
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(logicSql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) {
String currentDate = S2QLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
String currentDate = S2QLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) {
sql = SqlParserAddHelper.addParenthesisToWhere(sql);
sql = SqlParserAddHelper.addWhere(sql, TimeDimensionEnum.DAY.getChName(), currentDate);
logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql);
logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate);
}
}
semanticCorrectInfo.setSql(sql);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
}
private String getQueryFilter(QueryFilters queryFilters) {
@@ -95,9 +94,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
.collect(Collectors.joining(Constants.AND_UPPER));
}
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
Long modelId = semanticParseInfo.getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
@@ -107,8 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName);
semanticCorrectInfo.setSql(sql);
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getLogicSql(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
return;
}

View File

@@ -29,7 +29,7 @@ public abstract class BaseMapper implements SchemaMapper {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo());
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
try {
work(queryContext);
@@ -38,7 +38,7 @@ public abstract class BaseMapper implements SchemaMapper {
}
long cost = System.currentTimeMillis() - startTime;
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo());
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
}
public abstract void work(QueryContext queryContext);

View File

@@ -20,7 +20,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* match strategy implement
* Base Match Strategy
*/
@Service
@Slf4j
@@ -36,7 +36,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
List<T> detects = detect(queryContext, terms, detectModelIds);
Map<MatchText, List<T>> result = new HashMap<>();
@@ -143,7 +143,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
return;
}
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
}

View File

@@ -1,194 +1,26 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.corrector.CorrectorService;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j
@Service
public class LLMResponseService {
@Autowired
private CorrectorService correctorService;
public void addParseInfo(QueryContext queryCtx, ParseResult parseResult, String sql, Double weight) {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, parseResult, weight);
QueryFilters queryFilters = queryCtx.getRequest().getQueryFilters();
SemanticCorrectInfo semanticCorrectInfo = correctorService.correctorSql(queryFilters, parseInfo, sql);
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
updateParseInfo(semanticCorrectInfo, parseResult.getModelId(), parseInfo);
}
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getName())
).collect(Collectors.toSet());
}
private List<String> getFieldsExceptDate(List<String> allFields) {
if (CollectionUtils.isEmpty(allFields)) {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
String correctorSql = semanticCorrectInfo.getSql();
parseInfo.getSqlInfo().setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo
try {
if (!CollectionUtils.isEmpty(expressions)) {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
}
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql()));
Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) {
parseInfo.setNativeQuery(false);
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions()));
} else {
parseInfo.setNativeQuery(true);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions()));
}
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
if (Objects.isNull(schemaElement)) {
continue;
}
dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
dimensionFilter.setOperator(operatorEnum);
dimensionFilter.setFunction(expression.getFunction());
result.add(dimensionFilter);
}
return result;
}
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);
FilterExpression firstExpression = dateExpressions.get(0);
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
dateInfo.setDateMode(DateMode.BETWEEN);
return dateInfo;
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
}
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
}
}
return dateInfo;
}
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
}
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, ParseResult parseResult, Double weight) {
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2ql, Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
@@ -206,7 +38,7 @@ public class LLMResponseService {
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2QL(parseResult.getLlmResp().getSqlOutput());
parseInfo.getSqlInfo().setS2QL(s2ql);
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
@@ -219,30 +51,4 @@ public class LLMResponseService {
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
//support alias
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.flatMap(schemaElement -> {
Set<Pair<String, SchemaElement>> result = new HashSet<>();
result.add(Pair.of(schemaElement.getName(), schemaElement));
List<String> aliasList = schemaElement.getAlias();
if (!CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, schemaElement));
}
}
return result.stream();
})
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
}
}

View File

@@ -0,0 +1,291 @@
package com.tencent.supersonic.chat.query;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@Slf4j
@ToString
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
/**
 * Asks the semantic layer to explain the logic SQL currently held in {@code parseInfo}.
 *
 * @param user the user on whose behalf the explain request is sent
 * @return the SQL returned by the semantic layer, or {@code null} when the semantic
 *         layer returns no response or the request fails
 */
@Override
public String explain(User user) {
    ExplainSqlReq explainSqlReq = null;
    SqlInfo sqlInfo = parseInfo.getSqlInfo();
    try {
        QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
        explainSqlReq = ExplainSqlReq.builder()
                .queryTypeEnum(QueryTypeEnum.SQL)
                .queryReq(queryS2QLReq)
                .build();
        ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
        // A missing response is reported as "no SQL" instead of being dereferenced,
        // which previously raised a NullPointerException that was logged as an error.
        return Objects.isNull(explain) ? null : explain.getSql();
    } catch (Exception e) {
        log.error("explain error explainSqlReq:{}", explainSqlReq, e);
        return null;
    }
}
/**
 * Returns the parse info this query currently holds.
 *
 * @return the current {@link SemanticParseInfo} instance (not a copy)
 */
@Override
public SemanticParseInfo getParseInfo() {
    return this.parseInfo;
}
/**
 * Replaces the parse info this query operates on.
 *
 * @param parseInfo the parse info to use from now on; stored as-is, without copying
 */
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
/**
 * Builds the struct query request for the current parse info.
 *
 * @return the request produced by {@code QueryReqBuilder.buildStructReq} for {@code parseInfo}
 */
protected QueryStructReq convertQueryStruct() {
    QueryStructReq structReq = QueryReqBuilder.buildStructReq(parseInfo);
    return structReq;
}
/**
 * Refreshes {@code parseInfo} from its logic SQL: date info, dimension filters,
 * metrics, dimensions and the native-query flag are derived from the parsed SQL.
 *
 * <p>Does nothing when the logic SQL is blank. Failures while deriving the date info or
 * the dimension filters are logged and do not prevent the remaining fields from being set.
 * Metrics and dimensions are left untouched when no semantic schema is available.
 */
@Override
public void updateParseInfo() {
    String logicSql = parseInfo.getSqlInfo().getLogicSql();
    if (StringUtils.isBlank(logicSql)) {
        return;
    }
    Long modelId = parseInfo.getModelId();
    List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
    // Both steps below only make sense when the SQL carries filter expressions; guarding
    // them together keeps the filter step from iterating a null list and logging a
    // spurious error.
    if (CollectionUtils.isNotEmpty(expressions)) {
        //set dateInfo
        try {
            parseInfo.setDateInfo(getDateInfo(expressions));
        } catch (Exception e) {
            log.error("set dateInfo error :", e);
        }
        //set filter
        try {
            Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
            parseInfo.getDimensionFilters().addAll(getDimensionFilter(fieldNameToElement, expressions));
        } catch (Exception e) {
            log.error("set dimensionFilter error :", e);
        }
    }

    SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
    if (Objects.isNull(semanticSchema)) {
        return;
    }
    List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(logicSql));
    parseInfo.setMetrics(getElements(modelId, allFields, semanticSchema.getMetrics()));

    // Aggregated SQL takes its dimensions from GROUP BY; otherwise the selected
    // fields are used and the query is flagged as native.
    boolean aggregated = SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql);
    parseInfo.setNativeQuery(!aggregated);
    List<String> dimensionFields = aggregated
            ? SqlParserSelectHelper.getGroupByFields(logicSql)
            : SqlParserSelectHelper.getSelectFields(logicSql);
    parseInfo.setDimensions(
            getElements(modelId, getFieldsExceptDate(dimensionFields), semanticSchema.getDimensions()));
}
/**
 * Selects the schema elements that belong to the given model and whose name is one of
 * the given field names.
 *
 * @param modelId   id of the model the elements must belong to
 * @param allFields field names to look for
 * @param elements  candidate schema elements
 * @return the matching elements; an empty set when {@code modelId} is {@code null} or
 *         either list is {@code null} or empty
 */
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
    // The caller passes parseInfo.getModelId() and schema lists unchecked, so treat
    // missing input as "nothing matches" instead of failing with a NullPointerException.
    if (Objects.isNull(modelId) || CollectionUtils.isEmpty(allFields) || CollectionUtils.isEmpty(elements)) {
        return new HashSet<>();
    }
    // Hash lookup avoids scanning the field list once per schema element.
    Set<String> fieldNames = new HashSet<>(allFields);
    return elements.stream()
            .filter(schemaElement -> modelId.equals(schemaElement.getModel())
                    && fieldNames.contains(schemaElement.getName()))
            .collect(Collectors.toSet());
}
/**
 * Copies the given field names, leaving out every entry that equals the
 * {@code TimeDimensionEnum.DAY} display name (case-insensitive).
 *
 * @param allFields field names, may be {@code null} or empty
 * @return a new list with the remaining names; empty when the input is {@code null} or empty
 */
private List<String> getFieldsExceptDate(List<String> allFields) {
    List<String> nonDateFields = new ArrayList<>();
    if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
        return nonDateFields;
    }
    String dayName = TimeDimensionEnum.DAY.getChName();
    for (String field : allFields) {
        if (!dayName.equalsIgnoreCase(field)) {
            nonDateFields.add(field);
        }
    }
    return nonDateFields;
}
/**
 * Converts parsed filter expressions into query filters.
 *
 * <p>An expression is converted only when its field name is a key of
 * {@code fieldNameToElement}; all other expressions are skipped.
 *
 * @param fieldNameToElement lookup from field name to schema element
 * @param filterExpressions  filter expressions parsed from the SQL
 * @return one filter per expression whose field could be resolved, in input order
 */
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
        List<FilterExpression> filterExpressions) {
    List<QueryFilter> filters = Lists.newArrayList();
    for (FilterExpression filterExpression : filterExpressions) {
        SchemaElement element = fieldNameToElement.get(filterExpression.getFieldName());
        if (Objects.nonNull(element)) {
            QueryFilter filter = new QueryFilter();
            filter.setValue(filterExpression.getFieldValue());
            filter.setName(element.getName());
            filter.setBizName(element.getBizName());
            filter.setElementID(element.getId());
            filter.setOperator(FilterOperatorEnum.getSqlOperator(filterExpression.getOperator()));
            filter.setFunction(filterExpression.getFunction());
            filters.add(filter);
        }
    }
    return filters;
}
/**
 * Derives a date configuration from the filter expressions on the
 * {@code TimeDimensionEnum.DAY} field.
 *
 * <p>Only the first two day expressions are looked at. An equality on the first one
 * sets start and end to the same value. A greater-than(-or-equal) sets the start, and a
 * less-than(-or-equal) sets the end; in both cases the second expression, if it has a
 * value, supplies the opposite bound.
 *
 * @param filterExpressions filter expressions parsed from the SQL
 * @return a fresh {@link DateConf}; left at its defaults when there is no day expression,
 *         otherwise in {@code BETWEEN} mode with whatever bounds could be derived
 */
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
    String dayName = TimeDimensionEnum.DAY.getChName();
    List<FilterExpression> dayFilters = new ArrayList<>();
    for (FilterExpression candidate : filterExpressions) {
        if (dayName.equalsIgnoreCase(candidate.getFieldName())) {
            dayFilters.add(candidate);
        }
    }

    DateConf conf = new DateConf();
    if (org.springframework.util.CollectionUtils.isEmpty(dayFilters)) {
        return conf;
    }
    conf.setDateMode(DateMode.BETWEEN);

    FilterExpression head = dayFilters.get(0);
    FilterOperatorEnum headOperator = FilterOperatorEnum.getSqlOperator(head.getOperator());

    // day = X  ->  [X, X]
    if (FilterOperatorEnum.EQUALS.equals(headOperator) && Objects.nonNull(head.getFieldValue())) {
        conf.setStartDate(head.getFieldValue().toString());
        conf.setEndDate(head.getFieldValue().toString());
        conf.setDateMode(DateMode.BETWEEN);
        return conf;
    }

    boolean secondPresent = hasSecondDate(dayFilters);
    if (containOperators(head, headOperator, FilterOperatorEnum.GREATER_THAN,
            FilterOperatorEnum.GREATER_THAN_EQUALS)) {
        // First expression is the lower bound; the second one closes the range.
        conf.setStartDate(head.getFieldValue().toString());
        if (secondPresent) {
            conf.setEndDate(dayFilters.get(1).getFieldValue().toString());
        }
    } else if (containOperators(head, headOperator, FilterOperatorEnum.MINOR_THAN,
            FilterOperatorEnum.MINOR_THAN_EQUALS)) {
        // First expression is the upper bound; the second one opens the range.
        conf.setEndDate(head.getFieldValue().toString());
        if (secondPresent) {
            conf.setStartDate(dayFilters.get(1).getFieldValue().toString());
        }
    }
    return conf;
}
/**
 * Tells whether {@code firstOperator} is one of {@code operatorEnums} and the
 * expression carries a value.
 *
 * @param expression    the expression whose value is checked
 * @param firstOperator the operator to look for
 * @param operatorEnums the accepted operators
 * @return {@code true} only when the operator is accepted and the field value is non-null
 */
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
        FilterOperatorEnum... operatorEnums) {
    for (FilterOperatorEnum accepted : operatorEnums) {
        if (Objects.equals(accepted, firstOperator)) {
            // The value is inspected only after the operator matched.
            return Objects.nonNull(expression.getFieldValue());
        }
    }
    return false;
}
/**
 * Tells whether a second date expression exists and carries a value.
 *
 * @param dateExpressions the date expressions, in the order they were collected
 * @return {@code true} when the list has at least two entries and the second has a non-null value
 */
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
    if (dateExpressions.size() <= 1) {
        return false;
    }
    return dateExpressions.get(1).getFieldValue() != null;
}
/**
 * Builds a lookup from field name to schema element for one model. Every dimension and
 * metric of the model is registered under its name and under each of its aliases.
 *
 * <p>When several elements share a key, the one encountered last wins (dimensions are
 * processed before metrics).
 *
 * @param modelId id of the model whose elements are wanted
 * @return a map from name/alias to element; empty when the semantic schema is unavailable,
 *         {@code modelId} is {@code null}, or the model has no elements
 */
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
    SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
    List<SchemaElement> allElements = Lists.newArrayList();
    // The schema is treated as optional elsewhere in this class, so a missing schema
    // (or a missing element list) yields an empty lookup instead of a NullPointerException.
    if (Objects.nonNull(semanticSchema)) {
        List<SchemaElement> dimensions = semanticSchema.getDimensions();
        if (CollectionUtils.isNotEmpty(dimensions)) {
            allElements.addAll(dimensions);
        }
        List<SchemaElement> metrics = semanticSchema.getMetrics();
        if (CollectionUtils.isNotEmpty(metrics)) {
            allElements.addAll(metrics);
        }
    }
    //support alias
    return allElements.stream()
            // Compare from the modelId side so an element without a model is skipped
            // rather than dereferenced.
            .filter(schemaElement -> Objects.nonNull(modelId) && modelId.equals(schemaElement.getModel()))
            .flatMap(schemaElement -> {
                List<Pair<String, SchemaElement>> pairs = new ArrayList<>();
                pairs.add(Pair.of(schemaElement.getName(), schemaElement));
                List<String> aliasList = schemaElement.getAlias();
                if (CollectionUtils.isNotEmpty(aliasList)) {
                    for (String alias : aliasList) {
                        pairs.add(Pair.of(alias, schemaElement));
                    }
                }
                return pairs.stream();
            })
            .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
}
/**
 * Rewrites the given struct request in place using the schema's bizName-to-name mapping
 * for the request's model: order columns, aggregator columns, groups, and the names of
 * dimension and metric filters. Finally copies the model name from {@code parseInfo}.
 *
 * <p>NOTE(review): a value with no entry in the mapping is replaced by {@code null};
 * confirm that every column reaching this method is covered by the mapping.
 *
 * @param queryStructReq the request to rewrite
 */
protected void convertBizNameToName(QueryStructReq queryStructReq) {
    SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
    Map<String, String> bizNameToName = schemaService.getSemanticSchema()
            .getBizNameToName(queryStructReq.getModelId());

    List<Order> orders = queryStructReq.getOrders();
    if (CollectionUtils.isNotEmpty(orders)) {
        for (Order order : orders) {
            order.setColumn(bizNameToName.get(order.getColumn()));
        }
    }
    List<Aggregator> aggregators = queryStructReq.getAggregators();
    if (CollectionUtils.isNotEmpty(aggregators)) {
        for (Aggregator aggregator : aggregators) {
            aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
        }
    }
    List<String> groups = queryStructReq.getGroups();
    if (CollectionUtils.isNotEmpty(groups)) {
        groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList());
        queryStructReq.setGroups(groups);
    }
    List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
    if (CollectionUtils.isNotEmpty(dimensionFilters)) {
        dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
    }
    List<Filter> metricFilters = queryStructReq.getMetricFilters();
    // Guard on the metric filters themselves: the previous check tested the dimension
    // filters, which skipped metric filters or dereferenced a null list.
    if (CollectionUtils.isNotEmpty(metricFilters)) {
        metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
    }
    queryStructReq.setModelName(parseInfo.getModelName());
}
/**
 * Initializes both the S2QL and the logic SQL of {@code parseInfo} from the struct
 * request: the request is built, its biz names are converted to names, and the SQL of
 * the converted request is stored in both fields.
 */
protected void initS2SqlByStruct() {
    QueryStructReq structReq = convertQueryStruct();
    convertBizNameToName(structReq);
    String convertedSql = structReq.convert(structReq).getSql();
    SqlInfo target = parseInfo.getSqlInfo();
    target.setS2QL(convertedSql);
    target.setLogicSql(convertedSql);
}
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.corrector.CorrectorService;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
@@ -30,7 +29,6 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -42,8 +40,6 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "METRIC_INTERPRET";
@Autowired
private CorrectorService correctorService;
public MetricInterpretQuery() {
QueryManager.register(this);
@@ -56,15 +52,13 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
fillAggregator(queryStructReq, parseInfo.getMetrics());
queryStructReq.setNativeQuery(true);
QueryStructReq queryStructReq = convertQueryStruct();
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
if (optimizationConfig.isUseS2qlSwitch()) {
correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo);
queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL());
queryStructReq.setS2QL(parseInfo.getSqlInfo().getQuerySql());
}
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user);
@@ -87,6 +81,18 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
return queryResult;
}
@Override
public void initS2Sql(User user) {
initS2SqlByStruct();
}
protected QueryStructReq convertQueryStruct() {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
fillAggregator(queryStructReq, parseInfo.getMetrics());
queryStructReq.setNativeQuery(true);
return queryStructReq;
}
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches,
Map<String, String> replacedMap) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {

View File

@@ -10,10 +10,7 @@ import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import java.util.ArrayList;
import java.util.List;
@@ -65,22 +62,9 @@ public class S2QLQuery extends PluginSemanticQuery {
return queryResult;
}
@Override
public SqlInfo explain(User user) {
public void initS2Sql(User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
ExplainSqlReq explainSqlReq = null;
try {
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(queryS2QLReq)
.build();
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
sqlInfo.setQuerySql(explain.getSql());
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return sqlInfo;
sqlInfo.setLogicSql(sqlInfo.getS2QL());
}
}

View File

@@ -1,26 +1,19 @@
package com.tencent.supersonic.chat.query.plugin;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class PluginSemanticQuery implements SemanticQuery {
public abstract class PluginSemanticQuery extends BaseSemanticQuery {
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
public SemanticParseInfo getParseInfo() {
return parseInfo;
@Override
public String explain(User user) {
return null;
}
@Override
public SqlInfo explain(User user) {
return null;
public void initS2Sql(User user) {
}
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.query.rule;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -14,9 +13,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.corrector.CorrectorService;
import com.tencent.supersonic.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -24,14 +22,9 @@ import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -44,9 +37,8 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j
@ToString
public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
public abstract class RuleSemanticQuery extends BaseSemanticQuery {
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
protected QueryMatcher queryMatcher = new QueryMatcher();
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@@ -59,6 +51,11 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return queryMatcher.match(candidateElementMatches);
}
@Override
public void initS2Sql(User user) {
initS2SqlByStruct();
}
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
parseInfo.setQueryMode(getQueryMode());
@@ -203,10 +200,9 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
QueryStructReq queryStructReq = convertQueryStruct();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
if (optimizationConfig.isUseS2qlSwitch()) {
CorrectorService correctorService = ContextUtils.getBean(CorrectorService.class);
correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo);
queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL());
queryStructReq.setLogicSql(parseInfo.getSqlInfo().getLogicSql());
}
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
@@ -227,29 +223,6 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return queryResult;
}
@Override
public SqlInfo explain(User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
ExplainSqlReq explainSqlReq = null;
try {
QueryStructReq queryStructReq = convertQueryStruct();
CorrectorService correctorService = ContextUtils.getBean(CorrectorService.class);
correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo);
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(queryS2QLReq)
.build();
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
sqlInfo.setQuerySql(explain.getSql());
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return sqlInfo;
}
protected boolean isMultiStructQuery() {
return false;
}

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.ParseTimeCostDO;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
@@ -15,6 +14,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
public class SqlInfoParseResponder implements ParseResponder {
@@ -64,11 +64,11 @@ public class SqlInfoParseResponder implements ParseResponder {
return;
}
semanticQuery.setParseInfo(parseInfo);
SqlInfo sqlInfo = semanticQuery.explain(queryReq.getUser());
if (Objects.isNull(sqlInfo)) {
String explainSql = semanticQuery.explain(queryReq.getUser());
if (StringUtils.isBlank(explainSql)) {
return;
}
parseInfo.setSqlInfo(sqlInfo);
parseInfo.getSqlInfo().setQuerySql(explainSql);
}
}

View File

@@ -30,7 +30,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.config.AggregatorConfig;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.DateConf;
@@ -427,9 +426,6 @@ public class SemanticService {
queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField)));
queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, results));
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
if (Objects.nonNull(queryResp) && !CollectionUtils.isEmpty(queryResp.getResultList())) {

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.service.impl;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
@@ -20,8 +21,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
@@ -112,14 +111,18 @@ public class QueryServiceImpl implements QueryService {
private QuerySelector querySelector = ComponentFactory.getQuerySelector();
private List<ParseResponder> parseResponders = ComponentFactory.getParseResponders();
private List<ExecuteResponder> executeResponders = ComponentFactory.getExecuteResponders();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSqlCorrections();
@Override
public ParseResp performParsing(QueryReq queryReq) {
Long parseTime = System.currentTimeMillis();
//1. build queryContext and chatContext
QueryContext queryCtx = new QueryContext(queryReq);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
//2. mapper
schemaMappers.stream().forEach(mapper -> {
Long startTime = System.currentTimeMillis();
mapper.map(queryCtx);
@@ -127,6 +130,8 @@ public class QueryServiceImpl implements QueryService {
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
log.info("{} result:{}", mapper.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
//3. parser
semanticParsers.stream().forEach(parser -> {
Long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx);
@@ -134,17 +139,31 @@ public class QueryServiceImpl implements QueryService {
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
//4. corrector
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
semanticQuery.initS2Sql(queryReq.getUser());
semanticCorrectors.stream().forEach(correction -> {
correction.correct(queryReq, semanticQuery.getParseInfo());
});
semanticQuery.updateParseInfo();
}
}
//5. generate parsing results.
ParseResp parseResult;
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
if (queryCtx.getCandidateQueries().size() > 0) {
log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
if (candidateQueries.size() > 0) {
log.debug("pick before [{}]", candidateQueries.stream().collect(
Collectors.toList()));
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
log.debug("pick after [{}]", selectedQueries.stream().collect(
Collectors.toList()));
List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = convertParseInfo(queryCtx.getCandidateQueries());
List<SemanticParseInfo> candidateParses = convertParseInfo(candidateQueries);
candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses);
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
@@ -161,6 +180,7 @@ public class QueryServiceImpl implements QueryService {
.state(ParseResp.ParseState.FAILED)
.build();
}
//6. responders
for (ParseResponder parseResponder : parseResponders) {
Long startTime = System.currentTimeMillis();
parseResponder.fillResponse(parseResult, queryCtx, chatParseDOS);
@@ -315,14 +335,14 @@ public class QueryServiceImpl implements QueryService {
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(filedNameToValueMap, whereExpressionList, queryData.getDimensionFilters(),
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingFiledNameToValueMap, havingExpressionList, queryData.getDimensionFilters(),
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
@@ -333,9 +353,9 @@ public class QueryServiceImpl implements QueryService {
log.info("correctorSql after replacing:{}", correctorSql);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SqlInfo sqlInfo = semanticQuery.explain(user);
if (!Objects.isNull(sqlInfo)) {
parseInfo.setSqlInfo(sqlInfo);
String explainSql = semanticQuery.explain(user);
if (StringUtils.isNotBlank(explainSql)) {
parseInfo.getSqlInfo().setQuerySql(explainSql);
}
}
semanticQuery.setParseInfo(parseInfo);
@@ -432,8 +452,7 @@ public class QueryServiceImpl implements QueryService {
addConditions.add(comparisonExpression);
}
private void updateFilters(Map<String, Map<String, String>> filedNameToValueMap,
List<FilterExpression> filterExpressionList,
private void updateFilters(List<FilterExpression> filterExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
@@ -624,11 +643,7 @@ public class QueryServiceImpl implements QueryService {
groups.add(dimensionValueReq.getBizName());
queryStructReq.setGroups(groups);
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user);
return queryResultWithSchemaResp;
return semanticInterpreter.queryByStruct(queryStructReq, user);
}
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.Dim4Dict;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
@@ -19,7 +18,6 @@ import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
@@ -52,13 +50,10 @@ public class DictQueryHelper {
public List<String> fetchDimValueSingle(Long modelId, DefaultMetric defaultMetricDesc, Dim4Dict dim4Dict,
User user) {
User user) {
List<String> data = new ArrayList<>();
QueryStructReq queryStructCmd = generateQueryStructCmd(modelId, defaultMetricDesc, dim4Dict);
try {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
queryStructCmd.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch());
QueryResultWithSchemaResp queryResultWithColumns = semanticInterpreter.queryByStruct(queryStructCmd, user);
log.info("fetchDimValueSingle sql:{}", queryResultWithColumns.getSql());
@@ -100,7 +95,7 @@ public class DictQueryHelper {
}
private List<String> generateFileData(List<Map<String, Object>> resultList, String nature, String dimName,
String metricName, Dim4Dict dim4Dict) {
String metricName, Dim4Dict dim4Dict) {
List<String> data = new ArrayList<>();
if (CollectionUtils.isEmpty(resultList)) {
return data;
@@ -125,7 +120,7 @@ public class DictQueryHelper {
}
private void constructDataLines(Map<String, Long> valueAndFrequencyPair, String nature,
List<String> data, Dim4Dict dim4Dict) {
List<String> data, Dim4Dict dim4Dict) {
valueAndFrequencyPair.forEach((dimValue, metric) -> {
if (metric > MAX_FREQUENCY) {
metric = MAX_FREQUENCY;

View File

@@ -59,8 +59,6 @@ public class QueryStructReq {
private Boolean nativeQuery = false;
private Cache cacheInfo;
private boolean useS2qlSwitch;
/**
* Kept for backward compatibility only; to be removed later.
*/