(fix)(chat) fix the issue that front-end filter time re-query does not take effect, when the partition time field name is not imp_date #1638 (#1694)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-09-21 13:24:38 +08:00
committed by GitHub
parent a18c340a64
commit 4193b84e83
2 changed files with 26 additions and 23 deletions

View File

@@ -24,7 +24,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
@@ -49,6 +48,15 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.LongValue;
@@ -65,16 +73,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@Service @Service
public class ChatQueryServiceImpl implements ChatQueryService { public class ChatQueryServiceImpl implements ChatQueryService {
@@ -200,7 +198,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
SemanticParseInfo parseInfo = SemanticParseInfo parseInfo =
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId); chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq); parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
parseInfo.setSqlInfo(new SqlInfo());
DataSetSchema dataSetSchema = DataSetSchema dataSetSchema =
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
@@ -208,7 +205,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
handleLLMQueryMode(chatQueryDataReq, semanticQuery, user); handleLLMQueryMode(chatQueryDataReq, semanticQuery, dataSetSchema, user);
} else { } else {
handleRuleQueryMode(semanticQuery, dataSetSchema, user); handleRuleQueryMode(semanticQuery, dataSetSchema, user);
} }
@@ -225,7 +222,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void handleLLMQueryMode( private void handleLLMQueryMode(
ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, User user) ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, DataSetSchema dataSetSchema, User user)
throws Exception { throws Exception {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
List<String> fields = getFieldsFromSql(parseInfo); List<String> fields = getFieldsFromSql(parseInfo);
@@ -235,7 +232,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
replaceMetrics(parseInfo, metricToReplace); replaceMetrics(parseInfo, metricToReplace);
} else { } else {
log.info("llm begin revise filters!"); log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo); String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo, dataSetSchema);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
@@ -261,6 +258,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
queryResult.setEntityInfo(entityInfo); queryResult.setEntityInfo(entityInfo);
parseInfo.getSqlInfo().setQuerySQL(queryResult.getQuerySql());
return queryResult; return queryResult;
} }
@@ -273,7 +271,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return !oriFields.containsAll(metricNames); return !oriFields.containsAll(metricNames);
} }
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) { private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, DataSetSchema dataSetSchema) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql); log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter // get where filter and having filter
@@ -294,6 +292,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
updateDateInfo( updateDateInfo(
queryData, queryData,
parseInfo, parseInfo,
dataSetSchema,
filedNameToValueMap, filedNameToValueMap,
whereExpressionList, whereExpressionList,
addWhereConditions); addWhereConditions);
@@ -361,6 +360,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private Set<String> updateDateInfo( private Set<String> updateDateInfo(
ChatQueryDataReq queryData, ChatQueryDataReq queryData,
SemanticParseInfo parseInfo, SemanticParseInfo parseInfo,
DataSetSchema dataSetSchema,
Map<String, Map<String, String>> filedNameToValueMap, Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList, List<FieldExpression> fieldExpressionList,
List<Expression> addConditions) { List<Expression> addConditions) {
@@ -374,17 +374,18 @@ public class ChatQueryServiceImpl implements ChatQueryService {
.setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); .setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1)); queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
} }
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
// startDate equals to endDate // startDate equals to endDate
for (FieldExpression fieldExpression : fieldExpressionList) { for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) { if (partitionDimension.getName().equals(fieldExpression.getFieldName())) {
// first remove,then add // first remove,then add
removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); removeFieldNames.add(partitionDimension.getName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters( addTimeFilters(
queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions, partitionDimension);
MinorThanEquals minorThanEquals = new MinorThanEquals(); MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeFilters( addTimeFilters(
queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions); queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions, partitionDimension);
break; break;
} }
} }
@@ -414,8 +415,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private <T extends ComparisonOperator> void addTimeFilters( private <T extends ComparisonOperator> void addTimeFilters(
String date, T comparisonExpression, List<Expression> addConditions) { String date, T comparisonExpression, List<Expression> addConditions, SchemaElement partitionDimension) {
Column column = new Column(TimeDimensionEnum.DAY.getChName()); Column column = new Column(partitionDimension.getName());
StringValue stringValue = new StringValue(date); StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column); comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(stringValue); comparisonExpression.setRightExpression(stringValue);
@@ -548,6 +549,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (Objects.nonNull(queryData.getDateInfo())) { if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo()); parseInfo.setDateInfo(queryData.getDateInfo());
} }
parseInfo.setSqlInfo(new SqlInfo());
return parseInfo; return parseInfo;
} }

View File

@@ -36,7 +36,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n4.DO NOT calculate date range using functions." + "\n4.DO NOT calculate date range using functions."
+ "\n5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "\n5.DO NOT calculate date range using DATE_SUB."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n#Exemplars:\n{{exemplar}}" + "\n#Exemplars:\n{{exemplar}}"
+ "#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + "#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";