[fix][chat]Fix adjusted filters don't take effect.

This commit is contained in:
jerryjzhang
2024-12-08 13:27:34 +08:00
parent 4bd521020a
commit 43ba1a0ba6

View File

@@ -18,11 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.*;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.DateUtils;
@@ -48,11 +44,7 @@ import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -60,14 +52,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@@ -210,20 +195,22 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery,
DataSetSchema dataSetSchema, User user) throws Exception {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
List<String> fields = getFieldsFromSql(parseInfo);
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
log.info("llm begin replace metrics!");
String rebuiltS2SQL;
if (checkMetricReplace(chatQueryDataReq, parseInfo)) {
log.info("rebuild S2SQL with adjusted metrics!");
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
rebuiltS2SQL = replaceMetrics(parseInfo, metricToReplace);
} else {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo, dataSetSchema);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
log.info("rebuild S2SQL with adjusted filters!");
rebuiltS2SQL = replaceFilters(chatQueryDataReq, parseInfo, dataSetSchema);
}
// reset SqlInfo and request re-translation
parseInfo.getSqlInfo().setCorrectedS2SQL(rebuiltS2SQL);
parseInfo.getSqlInfo().setParsedS2SQL(rebuiltS2SQL);
parseInfo.getSqlInfo().setQuerySQL(null);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
}
private void handleRuleQueryMode(SemanticQuery semanticQuery, DataSetSchema dataSetSchema,
@@ -243,7 +230,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryResult;
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
private boolean checkMetricReplace(ChatQueryDataReq chatQueryDataReq, SemanticParseInfo parseInfo) {
List<String> oriFields = getFieldsFromSql(parseInfo);
Set<SchemaElement> metrics = chatQueryDataReq.getMetrics();
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
return false;
}
@@ -252,8 +241,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return !oriFields.containsAll(metricNames);
}
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
DataSetSchema dataSetSchema) {
private String replaceFilters(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
DataSetSchema dataSetSchema) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
@@ -290,7 +279,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
private String replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream().map(SchemaElement::getName)
.collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
@@ -302,7 +291,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
return correctorSql;
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user)
@@ -477,6 +466,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return;
}
@@ -492,9 +484,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
parseInfo.setMetricFilters(queryData.getMetricFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
parseInfo.setSqlInfo(new SqlInfo());
}