From 4193b84e837e2e64bf37bf658f01e369163c6153 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Sat, 21 Sep 2024 13:24:38 +0800 Subject: [PATCH] (fix)(chat) fix the issue that front-end filter time re-query does not take effect, when the partition time field name is not imp_date #1638 (#1694) Co-authored-by: lxwcodemonkey --- .../service/impl/ChatQueryServiceImpl.java | 46 ++++++++++--------- .../parser/llm/OnePassSCSqlGenStrategy.java | 3 +- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 4f0dbac0f..9a985fd0b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -24,7 +24,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; @@ -49,6 +48,15 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.LongValue; @@ -65,16 +73,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - @Slf4j @Service public class ChatQueryServiceImpl implements ChatQueryService { @@ -200,7 +198,6 @@ public class ChatQueryServiceImpl implements ChatQueryService { SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId); parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq); - parseInfo.setSqlInfo(new SqlInfo()); DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); @@ -208,7 +205,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { semanticQuery.setParseInfo(parseInfo); if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { - handleLLMQueryMode(chatQueryDataReq, semanticQuery, user); + handleLLMQueryMode(chatQueryDataReq, semanticQuery, dataSetSchema, user); } else { handleRuleQueryMode(semanticQuery, dataSetSchema, user); } @@ -225,7 +222,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void handleLLMQueryMode( - ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, User user) + ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, DataSetSchema dataSetSchema, User user) throws Exception { SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); List fields = getFieldsFromSql(parseInfo); @@ -235,7 +232,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { replaceMetrics(parseInfo, metricToReplace); } else { log.info("llm begin revise filters!"); - String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo); + String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo, dataSetSchema); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); semanticQuery.setParseInfo(parseInfo); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); @@ -261,6 +258,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user); queryResult.setEntityInfo(entityInfo); + parseInfo.getSqlInfo().setQuerySQL(queryResult.getQuerySql()); return queryResult; } @@ -273,7 +271,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { return !oriFields.containsAll(metricNames); } - private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) { + private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, DataSetSchema dataSetSchema) { String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("correctorSql before replacing:{}", correctorSql); // get where filter and having filter @@ -294,6 +292,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { updateDateInfo( queryData, parseInfo, + dataSetSchema, filedNameToValueMap, whereExpressionList, addWhereConditions); @@ -361,6 +360,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { private Set updateDateInfo( ChatQueryDataReq queryData, SemanticParseInfo parseInfo, + DataSetSchema dataSetSchema, Map> filedNameToValueMap, List fieldExpressionList, List addConditions) { @@ -374,17 +374,18 @@ public class ChatQueryServiceImpl implements ChatQueryService { .setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1)); } + SchemaElement partitionDimension = dataSetSchema.getPartitionDimension(); // startDate equals to endDate for (FieldExpression fieldExpression : fieldExpressionList) { - if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) { + if (partitionDimension.getName().equals(fieldExpression.getFieldName())) { // first remove,then add - removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); + removeFieldNames.add(partitionDimension.getName()); GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); addTimeFilters( - queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); + queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions, partitionDimension); MinorThanEquals minorThanEquals = new MinorThanEquals(); addTimeFilters( - queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions); + queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions, partitionDimension); break; } } @@ -414,8 +415,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void addTimeFilters( - String date, T comparisonExpression, List addConditions) { - Column column = new Column(TimeDimensionEnum.DAY.getChName()); + String date, T comparisonExpression, List addConditions, SchemaElement partitionDimension) { + Column column = new Column(partitionDimension.getName()); StringValue stringValue = new StringValue(date); comparisonExpression.setLeftExpression(column); comparisonExpression.setRightExpression(stringValue); @@ -548,6 +549,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { if (Objects.nonNull(queryData.getDateInfo())) { parseInfo.setDateInfo(queryData.getDateInfo()); } + parseInfo.setSqlInfo(new SqlInfo()); return parseInfo; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 8145da458..9db08ab6b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -36,7 +36,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "\n4.DO NOT calculate date range using functions." - + "\n5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + + "\n5.DO NOT calculate date range using DATE_SUB." + + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "\n#Exemplars:\n{{exemplar}}" + "#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";