diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 9c95d1fc3..92a1b6a4d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -18,11 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.util.ComponentFactory; import com.tencent.supersonic.chat.server.util.QueryReqConverter; -import com.tencent.supersonic.common.jsqlparser.FieldExpression; -import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; -import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; -import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; -import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.jsqlparser.*; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.util.DateUtils; @@ -48,11 +44,7 @@ import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; -import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; -import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; -import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; -import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList; +import net.sf.jsqlparser.expression.operators.relational.*; import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; @@ -60,14 +52,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; @Slf4j @@ -210,20 +195,22 @@ public class ChatQueryServiceImpl implements ChatQueryService { private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, DataSetSchema dataSetSchema, User user) throws Exception { SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); - List fields = getFieldsFromSql(parseInfo); - if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { - log.info("llm begin replace metrics!"); + String rebuiltS2SQL; + if (checkMetricReplace(chatQueryDataReq, parseInfo)) { + log.info("rebuild S2SQL with adjusted metrics!"); SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next(); - replaceMetrics(parseInfo, metricToReplace); + rebuiltS2SQL = replaceMetrics(parseInfo, metricToReplace); } else { - log.info("llm begin revise filters!"); - String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo, dataSetSchema); - parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); - semanticQuery.setParseInfo(parseInfo); - SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); - parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); + log.info("rebuild S2SQL with adjusted filters!"); + rebuiltS2SQL = replaceFilters(chatQueryDataReq, parseInfo, dataSetSchema); } + // reset SqlInfo and request re-translation + parseInfo.getSqlInfo().setCorrectedS2SQL(rebuiltS2SQL); + parseInfo.getSqlInfo().setParsedS2SQL(rebuiltS2SQL); + parseInfo.getSqlInfo().setQuerySQL(null); + SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); + SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); + parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); } private void handleRuleQueryMode(SemanticQuery semanticQuery, DataSetSchema dataSetSchema, @@ -243,7 +230,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { return queryResult; } - private boolean checkMetricReplace(List oriFields, Set metrics) { + private boolean checkMetricReplace(ChatQueryDataReq chatQueryDataReq, SemanticParseInfo parseInfo) { + List oriFields = getFieldsFromSql(parseInfo); + Set metrics = chatQueryDataReq.getMetrics(); if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) { return false; } @@ -252,8 +241,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { return !oriFields.containsAll(metricNames); } - private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, - DataSetSchema dataSetSchema) { + private String replaceFilters(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, + DataSetSchema dataSetSchema) { String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("correctorSql before replacing:{}", correctorSql); // get where filter and having filter @@ -290,7 +279,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { return correctorSql; } - private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { + private String replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { List oriMetrics = parseInfo.getMetrics().stream().map(SchemaElement::getName) .collect(Collectors.toList()); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); @@ -302,7 +291,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap); } log.info("after replaceMetrics:{}", correctorSql); - parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); + return correctorSql; } private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) @@ -477,6 +466,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) { + if (Objects.nonNull(queryData.getDateInfo())) { + parseInfo.setDateInfo(queryData.getDateInfo()); + } if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { return; } @@ -492,9 +484,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) { parseInfo.setMetricFilters(queryData.getMetricFilters()); } - if (Objects.nonNull(queryData.getDateInfo())) { - parseInfo.setDateInfo(queryData.getDateInfo()); - } + parseInfo.setSqlInfo(new SqlInfo()); }