(improvement)(chat)Add PostProcessor to do some logic after parser and corrector (#403)

* (improvement)(chat) Add PostProcessor to do some logic after parser and corrector

* (improvement)(chat) Add MetricCheckPostProcessor used to verify whether the dimensions involved in the query in metric mode can drill down on the metric

---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-11-20 09:58:25 +08:00
committed by GitHub
parent dd115f9d37
commit 0143b0a1b2
24 changed files with 690 additions and 249 deletions

View File

@@ -44,4 +44,33 @@ public class ModelSchema {
}
}
/**
 * Looks up a schema element of the given type.
 *
 * <p>For {@code ENTITY} and {@code MODEL} the single element held by this schema is
 * returned and {@code name} is ignored. For {@code METRIC}, {@code DIMENSION} and
 * {@code VALUE} the first element whose name equals {@code name} is returned.
 *
 * @param elementType the kind of element to look up
 * @param name the element name to match; a {@code null} name matches nothing
 * @return the matching element, or {@code null} when there is none
 */
public SchemaElement getElement(SchemaElementType elementType, String name) {
    Optional<SchemaElement> element = Optional.empty();
    switch (elementType) {
        case ENTITY:
            element = Optional.ofNullable(entity);
            break;
        case MODEL:
            // ofNullable (not of): an unset model must yield null, like an unset entity does.
            element = Optional.ofNullable(model);
            break;
        case METRIC:
            element = metrics.stream()
                    .filter(e -> name != null && name.equals(e.getName())).findFirst();
            break;
        case DIMENSION:
            element = dimensions.stream()
                    .filter(e -> name != null && name.equals(e.getName())).findFirst();
            break;
        case VALUE:
            element = dimensionValues.stream()
                    .filter(e -> name != null && name.equals(e.getName())).findFirst();
            break;
        default:
            // Any other element type has no by-name lookup here.
            break;
    }
    return element.orElse(null);
}
}

View File

@@ -1,9 +1,15 @@
package com.tencent.supersonic.chat.api.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class RelateSchemaElement {
private Long dimensionId;

View File

@@ -1,21 +1,11 @@
package com.tencent.supersonic.chat.api.pojo.response;
import cn.hutool.core.collection.CollectionUtil;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import lombok.Getter;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.AllArgsConstructor;
import java.util.List;
@Data
@Getter
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ParseResp {
private Integer chatId;
private String queryText;
@@ -23,8 +13,7 @@ public class ParseResp {
private ParseState state;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
private List<SolvedQueryRecallResp> similarSolvedQuery;
private ParseTimeCostDO parseTimeCost;
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
public enum ParseState {
COMPLETED,
@@ -32,12 +21,4 @@ public class ParseResp {
FAILED
}
public List<SemanticParseInfo> getSelectedParses() {
selectedParses = Lists.newArrayList();
if (CollectionUtil.isNotEmpty(candidateParses)) {
selectedParses.addAll(candidateParses);
candidateParses.clear();
}
return selectedParses;
}
}

View File

@@ -4,6 +4,12 @@ import lombok.Data;
@Data
public class ParseTimeCostDO {
private Long parseTime;
private Long sqlTime;
private long parseStartTime;
private long parseTime;
private long sqlTime;
public ParseTimeCostDO() {
this.parseStartTime = System.currentTimeMillis();
}
}

View File

@@ -1,100 +0,0 @@
package com.tencent.supersonic.chat.parser.rule;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Parser step that drops metric-mode candidate queries which, after filtering out
 * metrics whose necessary dimensions have no matched value, contain no metric match at all.
 */
public class MetricCheckParser implements SemanticParser {

    @Override
    public void parse(QueryContext queryContext, ChatContext chatContext) {
        List<SemanticQuery> candidates = queryContext.getCandidateQueries();
        if (!CollectionUtils.isEmpty(candidates)) {
            candidates.removeIf(this::removeQuery);
        }
    }

    /**
     * Decides whether a candidate should be discarded. Only {@link MetricSemanticQuery}
     * instances are ever removed.
     */
    private boolean removeQuery(SemanticQuery semanticQuery) {
        if (!(semanticQuery instanceof MetricSemanticQuery)) {
            return false;
        }
        SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
        List<SchemaElementMatch> allMatches = parseInfo.getElementMatches();
        List<SchemaElementMatch> survivingMatches = filterMetricElement(allMatches, parseInfo.getModelId());
        return getMetricElementMatchCount(survivingMatches) <= 0;
    }

    /**
     * Returns the matches with every metric removed whose necessary dimensions are not all
     * present among the matched VALUE elements. Non-metric matches are always kept.
     */
    private List<SchemaElementMatch> filterMetricElement(List<SchemaElementMatch> elementMatches, Long modelId) {
        List<SchemaElementMatch> kept = Lists.newArrayList();
        SemanticInterpreter interpreter = ComponentFactory.getSemanticLayer();
        ModelSchema schema = interpreter.getModelSchema(modelId, true);
        Set<SchemaElement> schemaMetrics = schema.getMetrics();
        Map<Long, SchemaElementMatch> valueMatchesById = getValueElementMap(elementMatches);
        Map<Long, SchemaElement> metricsById = schemaMetrics.stream()
                .collect(Collectors.toMap(SchemaElement::getId, e -> e, (e1, e2) -> e2));
        for (SchemaElementMatch match : elementMatches) {
            boolean isMetric = SchemaElementType.METRIC.equals(match.getElement().getType());
            if (isMetric) {
                SchemaElement metric = metricsById.get(match.getElement().getId());
                List<Long> requiredDimensionIds = getNecessaryDimensionIds(metric);
                // Keep the metric only when every necessary dimension has a matched value.
                if (!valueMatchesById.keySet().containsAll(requiredDimensionIds)) {
                    continue;
                }
            }
            kept.add(match);
        }
        return kept;
    }

    /** Indexes the VALUE-typed matches by element id; the first match wins on duplicate ids. */
    private Map<Long, SchemaElementMatch> getValueElementMap(List<SchemaElementMatch> elementMatches) {
        return elementMatches.stream()
                .filter(match -> SchemaElementType.VALUE.equals(match.getElement().getType()))
                .collect(Collectors.toMap(match -> match.getElement().getId(), match -> match,
                        (first, second) -> first));
    }

    /** Counts the METRIC-typed matches. */
    private long getMetricElementMatchCount(List<SchemaElementMatch> elementMatches) {
        long metricCount = 0;
        for (SchemaElementMatch match : elementMatches) {
            if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
                metricCount++;
            }
        }
        return metricCount;
    }

    /** Collects the ids of the dimensions flagged as necessary for the given metric. */
    private List<Long> getNecessaryDimensionIds(SchemaElement metric) {
        List<RelateSchemaElement> related = metric == null ? null : metric.getRelateSchemaElements();
        List<Long> necessaryIds = Lists.newArrayList();
        if (CollectionUtils.isEmpty(related)) {
            return necessaryIds;
        }
        for (RelateSchemaElement relation : related) {
            if (relation.isNecessary()) {
                necessaryIds.add(relation.getDimensionId());
            }
        }
        return necessaryIds;
    }
}

View File

@@ -19,7 +19,6 @@ public class RuleBasedParser implements SemanticParser {
new QueryModeParser(),
new ContextInheritParser(),
new AgentCheckParser(),
new MetricCheckParser(),
new TimeRangeParser(),
new AggregateTypeParser()
);

View File

@@ -4,7 +4,8 @@ public enum CostType {
MAPPER(1, "mapper"),
PARSER(2, "parser"),
QUERY(3, "query"),
PARSERRESPONDER(4, "responder");
PARSERRESPONDER(4, "responder"),
POSTPROCESSOR(5, "postprocessor");
private Integer type;
private String name;

View File

@@ -4,33 +4,31 @@ import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper;
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
@Repository
@Primary
@@ -136,7 +134,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
List<SemanticParseInfo> candidateParses) {
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatCtx, queryReq, queryId, 0, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
chatParseMapper.batchSaveParseInfo(chatParseDOList);
return chatParseDOList;
}
@@ -148,11 +146,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base,
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
parses.get(i).setId(base + i + 1);
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
@@ -161,7 +158,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
if (i == 0) {
chatParseDO.setIsCandidate(0);
}
chatParseDO.setParseId(base + i + 1);
chatParseDO.setParseId(parses.get(i).getId());
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());
chatParseDOList.add(chatParseDO);

View File

@@ -0,0 +1,202 @@
package com.tencent.supersonic.chat.postprocessor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * MetricCheckPostProcessor used to verify whether the dimensions
 * involved in the query in metric mode can drill down on the metric.
 *
 * <p>For every candidate whose query type is {@code METRIC}, the corrected S2SQL is rewritten
 * to drop metrics and dimensions that fail the checks below; candidates whose rewritten SQL
 * no longer contains any aggregate field are then removed from the candidate list.
 */
public class MetricCheckPostProcessor implements PostProcessor {

    /**
     * Rewrites the corrected S2SQL of metric-mode candidates and removes candidates left
     * without an aggregate field. Does nothing when there are no candidate queries.
     *
     * @param queryContext context whose candidate query list is updated in place
     */
    @Override
    public void process(QueryContext queryContext) {
        List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
        if (CollectionUtils.isEmpty(semanticQueries)) {
            return;
        }
        // Looked up lazily, once, and only if at least one metric-mode candidate exists.
        SchemaService schemaService = null;
        for (SemanticQuery semanticQuery : semanticQueries) {
            SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
            if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
                continue;
            }
            if (schemaService == null) {
                schemaService = ContextUtils.getBean(SchemaService.class);
            }
            ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
            String correctSqlProcessed = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
            parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed);
        }
        semanticQueries.removeIf(semanticQuery -> {
            if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
                return false;
            }
            String correctSql = semanticQuery.getParseInfo().getSqlInfo().getCorrectS2SQL();
            if (StringUtils.isBlank(correctSql)) {
                return false;
            }
            return CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql));
        });
    }

    /**
     * Removes from the SQL the metrics and dimensions that fail validation against the schema.
     *
     * @param correctSql the corrected S2SQL; returned unchanged when blank or when it has no
     *                   aggregate field
     * @param modelSchema schema of the model the SQL is written against
     * @return the rewritten SQL
     */
    public String processCorrectSql(String correctSql, ModelSchema modelSchema) {
        // Check for blank SQL before handing it to any SQL parser helper.
        if (StringUtils.isBlank(correctSql)) {
            return correctSql;
        }
        List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql);
        if (CollectionUtils.isEmpty(metricFields)) {
            return correctSql;
        }
        List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
        List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
        List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);

        Set<String> metricToRemove = Sets.newHashSet();
        for (String metricName : metricFields) {
            SchemaElement metricElement = modelSchema.getElement(SchemaElementType.METRIC, metricName);
            if (metricElement == null
                    || !checkNecessaryDimension(metricElement, modelSchema, dimensionFields)) {
                metricToRemove.add(metricName);
            }
        }
        Set<String> whereFieldsToRemove = findDimensionsToRemove(whereFields, metricFields, modelSchema);
        Set<String> groupByToRemove = findDimensionsToRemove(groupByFields, metricFields, modelSchema);
        return removeFieldInSql(correctSql, metricToRemove, groupByToRemove, whereFieldsToRemove);
    }

    /**
     * Collects the dimension names that are either missing from the schema or not usable as
     * drill-down dimensions of the queried metrics. Time dimensions are never collected.
     */
    private Set<String> findDimensionsToRemove(List<String> dimensionNames, List<String> metricFields,
            ModelSchema modelSchema) {
        Set<String> toRemove = Sets.newHashSet();
        if (CollectionUtils.isEmpty(dimensionNames)) {
            return toRemove;
        }
        for (String dimensionName : dimensionNames) {
            if (TimeDimensionEnum.getNameList().contains(dimensionName)) {
                continue;
            }
            if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, modelSchema)
                    || !checkDrillDownDimension(dimensionName, metricFields, modelSchema)) {
                toRemove.add(dimensionName);
            }
        }
        return toRemove;
    }

    /**
     * To check whether the dimension bound to the metric exists,
     * eg: metric like UV is calculated in a certain dimension, it cannot be used on other dimensions.
     */
    private boolean checkNecessaryDimension(SchemaElement metric, ModelSchema modelSchema,
            List<String> dimensionFields) {
        List<String> necessaryDimensions = getNecessaryDimensionNames(metric, modelSchema);
        if (CollectionUtils.isEmpty(necessaryDimensions)) {
            return true;
        }
        return dimensionFields.containsAll(necessaryDimensions);
    }

    /**
     * To check whether the dimension can drill down the metric,
     * eg: some descriptive dimensions are not suitable as drill-down dimensions
     */
    private boolean checkDrillDownDimension(String dimensionName, List<String> metrics,
            ModelSchema modelSchema) {
        List<SchemaElement> metricElements = modelSchema.getMetrics().stream()
                .filter(schemaElement -> metrics.contains(schemaElement.getName()))
                .collect(Collectors.toList());
        if (CollectionUtils.isEmpty(metricElements)) {
            return false;
        }
        List<String> relateDimensions = metricElements.stream()
                .filter(schemaElement -> !CollectionUtils.isEmpty(schemaElement.getRelateSchemaElements()))
                .map(schemaElement -> schemaElement.getRelateSchemaElements().stream()
                        .map(RelateSchemaElement::getDimensionId).collect(Collectors.toList()))
                .flatMap(Collection::stream)
                .map(id -> convertDimensionIdToName(id, modelSchema))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        //if no metric has drill down dimension, return true
        if (CollectionUtils.isEmpty(relateDimensions)) {
            return true;
        }
        //if this dimension not in relate drill-down dimensions, return false
        return relateDimensions.contains(dimensionName);
    }

    /** Resolves the metric's necessary dimension ids to names, skipping ids unknown to the schema. */
    private List<String> getNecessaryDimensionNames(SchemaElement metric, ModelSchema modelSchema) {
        List<Long> necessaryDimensionIds = getNecessaryDimensions(metric);
        return necessaryDimensionIds.stream().map(id -> convertDimensionIdToName(id, modelSchema))
                .filter(Objects::nonNull).collect(Collectors.toList());
    }

    /** Returns the ids of the related dimensions flagged as necessary; empty for a null metric. */
    private List<Long> getNecessaryDimensions(SchemaElement metric) {
        if (metric == null) {
            return Lists.newArrayList();
        }
        List<RelateSchemaElement> relateSchemaElements = metric.getRelateSchemaElements();
        if (CollectionUtils.isEmpty(relateSchemaElements)) {
            return Lists.newArrayList();
        }
        return relateSchemaElements.stream()
                .filter(RelateSchemaElement::isNecessary).map(RelateSchemaElement::getDimensionId)
                .collect(Collectors.toList());
    }

    /** Concatenates the group-by fields and the where fields, tolerating null or empty inputs. */
    private List<String> getDimensionFields(List<String> groupByFields, List<String> whereFields) {
        List<String> dimensionFields = Lists.newArrayList();
        if (!CollectionUtils.isEmpty(groupByFields)) {
            dimensionFields.addAll(groupByFields);
        }
        if (!CollectionUtils.isEmpty(whereFields)) {
            dimensionFields.addAll(whereFields);
        }
        return dimensionFields;
    }

    /** Returns the name of the dimension with the given id, or null when the schema has none. */
    private String convertDimensionIdToName(Long id, ModelSchema modelSchema) {
        SchemaElement schemaElement = modelSchema.getElement(SchemaElementType.DIMENSION, id);
        if (schemaElement == null) {
            return null;
        }
        return schemaElement.getName();
    }

    /** Tells whether the schema contains an element of the given type and name. */
    private boolean checkInModelSchema(String name, SchemaElementType type, ModelSchema modelSchema) {
        SchemaElement schemaElement = modelSchema.getElement(type, name);
        return schemaElement != null;
    }

    /** Applies the removal helpers to the SQL in a fixed order and returns the result. */
    private static String removeFieldInSql(String sql, Set<String> metricToRemove,
            Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
        sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
        sql = SqlParserRemoveHelper.removeSelect(sql, metricToRemove);
        sql = SqlParserRemoveHelper.removeSelect(sql, dimensionByToRemove);
        sql = SqlParserRemoveHelper.removeGroupBy(sql, dimensionByToRemove);
        sql = SqlParserRemoveHelper.removeNumberCondition(sql);
        return sql;
    }
}

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.chat.postprocessor;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.service.ParseInfoService;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Post processor that refreshes each candidate's parse info through
 * {@link ParseInfoService#updateParseInfo}.
 */
public class ParseInfoUpdateProcessor implements PostProcessor {

    @Override
    public void process(QueryContext queryContext) {
        List<SemanticQuery> queries = queryContext.getCandidateQueries();
        if (CollectionUtils.isEmpty(queries)) {
            // Nothing to update.
            return;
        }
        ParseInfoService service = ContextUtils.getBean(ParseInfoService.class);
        List<SemanticParseInfo> parseInfos = queries.stream()
                .map(query -> query.getParseInfo())
                .collect(Collectors.toList());
        for (SemanticParseInfo parseInfo : parseInfos) {
            service.updateParseInfo(parseInfo);
        }
    }
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.postprocessor;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
/**
* A post processor runs additional logic on the query context after the parsers
* and correctors have finished.
*/
public interface PostProcessor {
/**
* Processes the given query context.
*
* @param queryContext the context of the current query; the implementations shipped
*     alongside this interface read its candidate queries and modify them in place
*/
void process(QueryContext queryContext);
}

View File

@@ -4,13 +4,13 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@Component
@@ -30,6 +30,7 @@ public class QueryRanker {
} else {
selectedQueries = getTopCandidateQuery(candidateQueries);
}
generateParseInfoId(selectedQueries);
log.debug("pick after [{}]", selectedQueries);
return selectedQueries;
}
@@ -48,6 +49,13 @@ public class QueryRanker {
.collect(Collectors.toList());
}
/**
 * Numbers the parse info of each query with its 1-based position in the list.
 */
private void generateParseInfoId(List<SemanticQuery> semanticQueries) {
    int nextId = 1;
    for (SemanticQuery semanticQuery : semanticQueries) {
        semanticQuery.getParseInfo().setId(nextId);
        nextId++;
    }
}
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {

View File

@@ -1,27 +1,30 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
public class EntityInfoParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
List<ChatParseDO> chatParseDOS) {
List<SemanticParseInfo> selectedParses = parseResp.getCandidateParses();
if (CollectionUtils.isEmpty(selectedParses)) {
public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode();

View File

@@ -0,0 +1,37 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Fills the basic fields of the parse response (chat id, query text, state and candidate
 * parses) from the query context, and passes a completed response to
 * {@link ChatService#batchAddParse}.
 */
@Slf4j
public class ParseRespBuildParseResponder implements ParseResponder {

    /**
     * Populates {@code parseResp}. The state is {@code COMPLETED} when the context holds at
     * least one candidate query and {@code FAILED} otherwise (including a null candidate list).
     *
     * @param parseResp the response to populate
     * @param queryContext supplies the request and the candidate queries
     * @param chatContext passed through to {@code ChatService.batchAddParse}
     */
    @Override
    public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        QueryReq queryReq = queryContext.getRequest();
        parseResp.setChatId(queryReq.getChatId());
        parseResp.setQueryText(queryReq.getQueryText());

        List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
        if (candidateQueries == null || candidateQueries.isEmpty()) {
            parseResp.setState(ParseResp.ParseState.FAILED);
            return;
        }
        List<SemanticParseInfo> candidateParses = candidateQueries.stream()
                .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
        // Set once; the original assigned the same list twice.
        parseResp.setCandidateParses(candidateParses);
        parseResp.setState(ParseResp.ParseState.COMPLETED);
        ChatService chatService = ContextUtils.getBean(ChatService.class);
        chatService.batchAddParse(chatContext, queryReq, parseResp);
    }
}

View File

@@ -1,12 +1,11 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import java.util.List;
public interface ParseResponder {
void fillResponse(ParseResp parseResp, QueryContext queryContext, List<ChatParseDO> chatParseDOS);
void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext);
}

View File

@@ -0,0 +1,20 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
/**
 * Records the parse time on the response as the time elapsed since the recorded
 * parse start, minus the recorded SQL time.
 */
@Slf4j
public class ParseTimeParseResponder implements ParseResponder {

    @Override
    public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        long startedAt = parseResp.getParseTimeCost().getParseStartTime();
        long elapsedSinceStart = System.currentTimeMillis() - startedAt;
        long sqlTime = parseResp.getParseTimeCost().getSqlTime();
        parseResp.getParseTimeCost().setParseTime(elapsedSinceStart - sqlTime);
    }
}

View File

@@ -0,0 +1,24 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.query.QueryRanker;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
/**
 * Replaces the context's candidate queries with the list returned by {@link QueryRanker#rank}.
 */
public class QueryRankParseResponder implements ParseResponder {

    @Override
    public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        List<SemanticQuery> unranked = queryContext.getCandidateQueries();
        QueryRanker ranker = ContextUtils.getBean(QueryRanker.class);
        List<SemanticQuery> ranked = ranker.rank(unranked);
        queryContext.setCandidateQueries(ranked);
    }
}

View File

@@ -1,50 +1,32 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.ParseTimeCostDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
public class SqlInfoParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
List<ChatParseDO> chatParseDOS) {
public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest();
Long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, parseResp.getCandidateParses());
parseResp.setParseTimeCost(new ParseTimeCostDO());
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
if (!CollectionUtils.isEmpty(chatParseDOS)) {
Map<Integer, ChatParseDO> chatParseDOMap = chatParseDOS.stream()
.collect(Collectors.toMap(ChatParseDO::getParseId,
Function.identity(), (oldValue, newValue) -> newValue));
updateParseInfo(chatParseDOMap, parseResp.getCandidateParses());
}
}
private void updateParseInfo(Map<Integer, ChatParseDO> chatParseDOMap, List<SemanticParseInfo> parseInfos) {
if (CollectionUtils.isEmpty(parseInfos)) {
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
}
for (SemanticParseInfo parseInfo : parseInfos) {
ChatParseDO chatParseDO = chatParseDOMap.get(parseInfo.getId());
if (chatParseDO != null) {
chatParseDO.setParseInfo(JsonUtil.toString(parseInfo));
}
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, selectedParses);
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
}
private void addSqlInfo(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {

View File

@@ -26,14 +26,13 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.postprocessor.PostProcessor;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QueryRanker;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ParseInfoService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
@@ -61,15 +60,6 @@ import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
@@ -85,7 +75,6 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -93,6 +82,16 @@ import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Component("chatQueryService")
@Primary
@@ -105,23 +104,20 @@ public class QueryServiceImpl implements QueryService {
private StatisticsService statisticsService;
@Autowired
private SolvedQueryManager solvedQueryManager;
@Autowired
private ParseInfoService parseInfoService;
@Autowired
private QueryRanker queryRanker;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<PostProcessor> postProcessors = ComponentFactory.getPostProcessors();
private List<ParseResponder> parseResponders = ComponentFactory.getParseResponders();
private List<ExecuteResponder> executeResponders = ComponentFactory.getExecuteResponders();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSqlCorrections();
@Override
public ParseResp performParsing(QueryReq queryReq) {
Long parseTime = System.currentTimeMillis();
ParseResp parseResult = new ParseResp();
//1. build queryContext and chatContext
QueryContext queryCtx = new QueryContext(queryReq);
// in order to support multi-turn conversation, chat context is needed
@@ -129,16 +125,16 @@ public class QueryServiceImpl implements QueryService {
List<StatisticsDO> timeCostDOList = new ArrayList<>();
//2. mapper
schemaMappers.stream().forEach(mapper -> {
Long startTime = System.currentTimeMillis();
schemaMappers.forEach(mapper -> {
long startTime = System.currentTimeMillis();
mapper.map(queryCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
});
//3. parser
semanticParsers.stream().forEach(parser -> {
Long startTime = System.currentTimeMillis();
semanticParsers.forEach(parser -> {
long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
@@ -153,46 +149,34 @@ public class QueryServiceImpl implements QueryService {
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
semanticCorrectors.stream().forEach(correction -> {
semanticCorrectors.forEach(correction -> {
correction.correct(queryReq, semanticQuery.getParseInfo());
});
}
}
//5. generate parsing results.
ParseResp parseResult;
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
if (candidateQueries.size() > 0) {
candidateQueries = queryRanker.rank(candidateQueries);
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(parseInfo -> parseInfoService.updateParseInfo(parseInfo));
parseResult = ParseResp.builder().chatId(queryReq.getChatId()).queryText(queryReq.getQueryText())
.state(ParseResp.ParseState.COMPLETED).candidateParses(candidateParses)
.build();
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
} else {
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText())
.state(ParseResp.ParseState.FAILED)
.build();
}
//5. postProcessor
postProcessors.forEach(postProcessor -> {
long startTime = System.currentTimeMillis();
postProcessor.process(queryCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(postProcessor.getClass().getSimpleName())
.type(CostType.POSTPROCESSOR.getType()).build());
});
//6. responders
for (ParseResponder parseResponder : parseResponders) {
Long startTime = System.currentTimeMillis();
parseResponder.fillResponse(parseResult, queryCtx, chatParseDOS);
parseResponders.forEach(parseResponder -> {
long startTime = System.currentTimeMillis();
parseResponder.fillResponse(parseResult, queryCtx, chatCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parseResponder.getClass().getSimpleName())
.type(CostType.PARSERRESPONDER.getType()).build());
}
});
if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) {
saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
saveTimeCostInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
}
chatService.updateChatParse(chatParseDOS);
parseResult.getParseTimeCost().setParseTime(
System.currentTimeMillis() - parseTime - parseResult.getParseTimeCost().getSqlTime());
return parseResult;
}
@@ -220,7 +204,7 @@ public class QueryServiceImpl implements QueryService {
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
queryResult.setQueryTimeCost(timeCostDOList.get(0).getCost().longValue());
saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
saveTimeCostInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
@@ -242,7 +226,7 @@ public class QueryServiceImpl implements QueryService {
}
// save time cost data
private void saveInfo(List<StatisticsDO> timeCostDOList,
private void saveTimeCostInfo(List<StatisticsDO> timeCostDOList,
String queryText, Long queryId,
String userName, Long chatId) {
List<StatisticsDO> list = timeCostDOList.stream()

View File

@@ -1,19 +1,18 @@
package com.tencent.supersonic.chat.utils;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
import com.tencent.supersonic.chat.postprocessor.PostProcessor;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
public class ComponentFactory {
@@ -21,6 +20,7 @@ public class ComponentFactory {
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<SemanticCorrector> s2SQLCorrections = new ArrayList<>();
private static SemanticInterpreter semanticInterpreter;
private static List<PostProcessor> postProcessors = new ArrayList<>();
private static List<ParseResponder> parseResponders = new ArrayList<>();
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
private static ModelResolver modelResolver;
@@ -37,6 +37,10 @@ public class ComponentFactory {
s2SQLCorrections) : s2SQLCorrections;
}
/**
 * Returns the registered {@link PostProcessor} components.
 *
 * <p>While the cached list is still null or empty the call is delegated to
 * {@code init(PostProcessor.class, postProcessors)} and its result is returned;
 * otherwise the cached list itself is returned.
 * NOTE(review): {@code init} is not visible here — presumably it fills the cache
 * from the spring.factories entries for the given type; confirm.
 *
 * @return the list of post processors
 */
public static List<PostProcessor> getPostProcessors() {
    if (CollectionUtils.isEmpty(postProcessors)) {
        return init(PostProcessor.class, postProcessors);
    }
    return postProcessors;
}
/**
 * Returns the registered {@link ParseResponder} components.
 *
 * <p>While the cached list is still null or empty the call is delegated to
 * {@code init(ParseResponder.class, parseResponders)} and its result is returned;
 * otherwise the cached list itself is returned.
 * NOTE(review): {@code init} is not visible here — presumably it fills the cache
 * from the spring.factories entries for the given type; confirm.
 *
 * @return the list of parse responders
 */
public static List<ParseResponder> getParseResponders() {
    if (CollectionUtils.isEmpty(parseResponders)) {
        return init(ParseResponder.class, parseResponders);
    }
    return parseResponders;
}

View File

@@ -0,0 +1,124 @@
package com.tencent.supersonic.chat.postprocessor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
/**
 * Tests for {@code MetricCheckPostProcessor.processCorrectSql}.
 *
 * <p>Each case feeds a corrected SQL string plus a mocked {@link ModelSchema} to the
 * processor and compares the returned SQL with the expected text.
 */
class MetricCheckPostProcessorTest {

    /**
     * 访问用户数 requires 部门, which is neither grouped nor filtered here:
     * the expected SQL no longer selects 访问用户数.
     */
    @Test
    void testProcessCorrectSql_necessaryDimension_groupBy() {
        assertProcessedSql(
                "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名",
                mockModelSchema(),
                "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名");
    }

    /**
     * The necessary dimension 部门 appears in the WHERE clause:
     * the expected SQL keeps both metrics.
     */
    @Test
    void testProcessCorrectSql_necessaryDimension_where() {
        assertProcessedSql(
                "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名",
                mockModelSchema(),
                "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
                        + "WHERE 部门 = 'HR' GROUP BY 用户名");
    }

    /**
     * 页面 is not a drill-down dimension of either metric:
     * the expected SQL drops it from both SELECT and GROUP BY.
     */
    @Test
    void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
        assertProcessedSql(
                "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门",
                mockModelSchema(),
                "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门");
    }

    /**
     * 页面 is only used as a WHERE filter and is not a drill-down dimension:
     * the expected SQL has the filter removed.
     */
    @Test
    void testProcessCorrectSql_dimensionNotDrillDown_where() {
        assertProcessedSql(
                "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门",
                mockModelSchema(),
                "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门");
    }

    /**
     * 页面 cannot be drilled down and the necessary dimension 部门 is missing:
     * the expected SQL keeps only sum(访问次数) without any grouping.
     */
    @Test
    void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
        assertProcessedSql(
                "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面",
                mockModelSchema(),
                "SELECT sum(访问次数) FROM 超音数");
    }

    /**
     * Every dimension used is a configured drill-down dimension and 部门 is present:
     * the expected SQL keeps all columns.
     */
    @Test
    void testProcessCorrectSql_dimensionDrillDown() {
        assertProcessedSql(
                "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门",
                mockModelSchema(),
                "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门");
    }

    /**
     * Neither metric declares drill-down dimensions:
     * the expected SQL keeps all columns.
     */
    @Test
    void testProcessCorrectSql_noDrillDownDimensionSetting() {
        assertProcessedSql(
                "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名",
                mockModelSchemaNoDimensionSetting(),
                "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名");
    }

    /**
     * Runs a fresh {@code MetricCheckPostProcessor} on {@code correctSql} against
     * {@code modelSchema} and asserts that the result equals {@code expectedSql}.
     */
    private void assertProcessedSql(String correctSql, ModelSchema modelSchema, String expectedSql) {
        MetricCheckPostProcessor processor = new MetricCheckPostProcessor();
        String processedSql = processor.processCorrectSql(correctSql, modelSchema);
        Assertions.assertEquals(expectedSql, processedSql);
    }

    /**
     * Schema fixture with drill-down settings:
     * 访问次数 can be drilled down by 部门 and 用户名 (neither is necessary);
     * 访问用户数 can be drilled down by 部门 only, and 部门 is marked as necessary.
     */
    private ModelSchema mockModelSchema() {
        SchemaElement visitCount = mockElement(1L, "访问次数", SchemaElementType.METRIC,
                Lists.newArrayList(drillDown(2L, false), drillDown(1L, false)));
        SchemaElement visitUserCount = mockElement(2L, "访问用户数", SchemaElementType.METRIC,
                Lists.newArrayList(drillDown(2L, true)));
        return schemaOf(Sets.newHashSet(visitCount, visitUserCount));
    }

    /**
     * Schema fixture in which both metrics carry an empty drill-down list.
     */
    private ModelSchema mockModelSchemaNoDimensionSetting() {
        SchemaElement visitCount = mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList());
        SchemaElement visitUserCount = mockElement(2L, "访问用户数", SchemaElementType.METRIC, Lists.newArrayList());
        return schemaOf(Sets.newHashSet(visitCount, visitUserCount));
    }

    /**
     * Builds a schema from the given metrics plus the shared dimension fixture.
     */
    private ModelSchema schemaOf(Set<SchemaElement> metrics) {
        ModelSchema schema = new ModelSchema();
        schema.setMetrics(metrics);
        schema.setDimensions(mockDimensions());
        return schema;
    }

    /**
     * Dimension fixture: 用户名 (id 1), 部门 (id 2) and 页面 (id 3).
     */
    private Set<SchemaElement> mockDimensions() {
        SchemaElement userName = mockElement(1L, "用户名", SchemaElementType.DIMENSION, Lists.newArrayList());
        SchemaElement department = mockElement(2L, "部门", SchemaElementType.DIMENSION, Lists.newArrayList());
        SchemaElement page = mockElement(3L, "页面", SchemaElementType.DIMENSION, Lists.newArrayList());
        return Sets.newHashSet(userName, department, page);
    }

    /**
     * Builds a drill-down relation to the dimension with the given id.
     */
    private RelateSchemaElement drillDown(Long dimensionId, boolean necessary) {
        return RelateSchemaElement.builder()
                .dimensionId(dimensionId)
                .isNecessary(necessary)
                .build();
    }

    /**
     * Builds a schema element with the given id, name, type and related elements.
     */
    private SchemaElement mockElement(Long id, String name, SchemaElementType type,
            List<RelateSchemaElement> relateSchemaElements) {
        return SchemaElement.builder()
                .id(id)
                .name(name)
                .type(type)
                .relateSchemaElements(relateSchemaElements)
                .build();
    }
}

View File

@@ -1,30 +1,37 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* Sql Parser remove Helper
@@ -56,6 +63,9 @@ public class SqlParserRemoveHelper {
}
public static String removeNumberCondition(String sql) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
if (selectStatement == null) {
return sql;
}
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
@@ -191,6 +201,55 @@ public class SqlParserRemoveHelper {
return selectStatement.toString();
}
/**
 * Drops select items whose column name is contained in {@code fields}.
 *
 * <p>The input is returned untouched when {@code SqlParserSelectHelper.getSelect}
 * yields {@code null} or the statement body is not a {@link PlainSelect}. Otherwise the
 * select item list is filtered in place and the statement is rendered back to text.
 * Only {@link SelectExpressionItem} entries are candidates for removal; their name is
 * resolved through {@code SqlParserSelectHelper.getColumnName}.
 *
 * @param sql    the SQL text to rewrite
 * @param fields names of the select columns to remove
 * @return the rewritten SQL, or {@code sql} itself when no rewrite was attempted
 */
public static String removeSelect(String sql, Set<String> fields) {
    Select parsedSelect = SqlParserSelectHelper.getSelect(sql);
    SelectBody body = parsedSelect == null ? null : parsedSelect.getSelectBody();
    if (!(body instanceof PlainSelect)) {
        // covers both "could not obtain a Select" and "not a plain SELECT"
        return sql;
    }
    PlainSelect plainSelect = (PlainSelect) body;
    plainSelect.getSelectItems().removeIf(item -> item instanceof SelectExpressionItem
            && fields.contains(
                    SqlParserSelectHelper.getColumnName(((SelectExpressionItem) item).getExpression())));
    return parsedSelect.toString();
}
/**
 * Drops plain column references named in {@code fields} from the GROUP BY clause.
 *
 * <p>The input is returned untouched when {@code SqlParserSelectHelper.getSelect}
 * yields {@code null}, the statement body is not a {@link PlainSelect}, there is no
 * GROUP BY element, or the GROUP BY element exposes no expression list to filter.
 * If the expression list is empty after filtering, the whole GROUP BY element is
 * removed from the statement.
 *
 * <p>Only expressions that are {@link Column} instances are candidates for removal;
 * they are matched by {@code getColumnName()}.
 *
 * @param sql    the SQL text to rewrite
 * @param fields names of the group-by columns to remove; {@code null} is treated as
 *               "remove nothing"
 * @return the rewritten SQL, or {@code sql} itself when no rewrite was attempted
 */
public static String removeGroupBy(String sql, Set<String> fields) {
    Select selectStatement = SqlParserSelectHelper.getSelect(sql);
    if (selectStatement == null) {
        return sql;
    }
    SelectBody selectBody = selectStatement.getSelectBody();
    if (!(selectBody instanceof PlainSelect)) {
        return sql;
    }
    PlainSelect plainSelect = (PlainSelect) selectBody;
    GroupByElement groupByElement = plainSelect.getGroupBy();
    if (groupByElement == null) {
        return sql;
    }
    ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
    List<Expression> groupByExpressions =
            groupByExpressionList == null ? null : groupByExpressionList.getExpressions();
    if (groupByExpressions == null) {
        // Guard before dereferencing: the list can be absent, and calling removeIf on it
        // would throw. NOTE(review): presumably this happens for GROUP BY clauses made
        // only of grouping sets — confirm against the JSQLParser version in use.
        return sql;
    }
    if (fields != null) {
        groupByExpressions.removeIf(expression ->
                expression instanceof Column && fields.contains(((Column) expression).getColumnName()));
    }
    if (groupByExpressions.isEmpty()) {
        // an empty GROUP BY would not render as valid SQL, so drop the clause entirely
        plainSelect.setGroupByElement(null);
    }
    return selectStatement.toString();
}
private static Expression filteredWhereExpression(Expression where) {
if (Objects.isNull(where)) {
return null;

View File

@@ -31,9 +31,16 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.postprocessor.PostProcessor=\
com.tencent.supersonic.chat.postprocessor.MetricCheckPostProcessor, \
com.tencent.supersonic.chat.postprocessor.ParseInfoUpdateProcessor
com.tencent.supersonic.chat.responder.parse.ParseResponder=\
com.tencent.supersonic.chat.responder.parse.QueryRankParseResponder, \
com.tencent.supersonic.chat.responder.parse.EntityInfoParseResponder, \
com.tencent.supersonic.chat.responder.parse.SqlInfoParseResponder
com.tencent.supersonic.chat.responder.parse.SqlInfoParseResponder, \
com.tencent.supersonic.chat.responder.parse.ParseTimeParseResponder, \
com.tencent.supersonic.chat.responder.parse.ParseRespBuildParseResponder
com.tencent.supersonic.chat.responder.execute.ExecuteResponder=\
com.tencent.supersonic.chat.responder.execute.EntityInfoExecuteResponder, \

View File

@@ -1,19 +1,47 @@
com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.mapper.HanlpDictMapper
com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
com.tencent.supersonic.chat.mapper.EntityMapper
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.RuleBasedParser, \
com.tencent.supersonic.chat.parser.llm.interpret.MetricInterpretParser
# com.tencent.supersonic.chat.parser.llm.DSLQueryFunction
com.tencent.supersonic.chat.parser.llm.s2sql.LLMS2SQLParser, \
com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingBasedParser, \
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser, \
com.tencent.supersonic.chat.parser.QueryTypeParser
com.tencent.supersonic.chat.api.component.QueryProcessor=\
com.tencent.supersonic.chat.application.processor.SemanticQueryProcessor
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
com.tencent.supersonic.chat.application.query.DomainResolver=\
com.tencent.supersonic.chat.application.query.HeuristicDomainResolver
com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\
com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver
com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor=\
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.postprocessor.PostProcessor=\
com.tencent.supersonic.chat.postprocessor.MetricCheckPostProcessor, \
com.tencent.supersonic.chat.postprocessor.ParseInfoUpdateProcessor
com.tencent.supersonic.chat.responder.parse.ParseResponder=\
com.tencent.supersonic.chat.responder.parse.QueryRankParseResponder, \
com.tencent.supersonic.chat.responder.parse.EntityInfoParseResponder, \
com.tencent.supersonic.chat.responder.parse.SqlInfoParseResponder, \
com.tencent.supersonic.chat.responder.parse.ParseTimeParseResponder, \
com.tencent.supersonic.chat.responder.parse.ParseRespBuildParseResponder
com.tencent.supersonic.chat.responder.execute.ExecuteResponder=\
com.tencent.supersonic.chat.responder.execute.EntityInfoExecuteResponder, \
com.tencent.supersonic.chat.responder.execute.SimilarMetricExecuteResponder