From ba55ecb31eacdc60991521085717640eef1aed07 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 19 Aug 2024 01:21:10 +0800 Subject: [PATCH] (improvement)(chat) Make corrections and pass the data date format to the large model. (#1583) --- .../common/pojo/DimensionConstants.java | 1 + .../headless/api/pojo/DataSetSchema.java | 10 +++ .../headless/api/pojo/SchemaElement.java | 33 ++++++++-- .../chat/corrector/S2SqlDateHelper.java | 66 +++++++++++++------ .../chat/corrector/TimeCorrector.java | 14 ++-- .../chat/parser/llm/LLMRequestService.java | 65 ++++++++++-------- .../chat/query/rule/RuleSemanticQuery.java | 6 +- .../server/utils/DataSetSchemaBuilder.java | 1 + 8 files changed, 136 insertions(+), 60 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/DimensionConstants.java b/common/src/main/java/com/tencent/supersonic/common/pojo/DimensionConstants.java index 346ff00d9..238309d70 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/DimensionConstants.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/DimensionConstants.java @@ -4,6 +4,7 @@ public class DimensionConstants { public static final String DIMENSION_TIME_FORMAT = "time_format"; + public static final String DIMENSION_TYPE = "dimension_type"; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index 2e2bc9b35..aa3624536 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo; import lombok.Data; import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; import java.util.HashSet; @@ -126,4 +127,13 @@ public class DataSetSchema { return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime); } + public String getPartitionTimeFormat() { + for (SchemaElement dimension : dimensions) { + String partitionTimeFormat = dimension.getPartitionTimeFormat(); + if (StringUtils.isNotBlank(partitionTimeFormat)) { + return partitionTimeFormat; + } + } + return null; + } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 7beb729af..4710229e6 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -3,16 +3,18 @@ package com.tencent.supersonic.headless.api.pojo; import com.google.common.base.Objects; import com.tencent.supersonic.common.pojo.DimensionConstants; import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; -import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; import org.apache.commons.collections4.MapUtils; +import org.apache.commons.lang3.StringUtils; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.List; +import java.util.Map; @Data @Getter @@ -65,8 +67,29 @@ public class SchemaElement implements Serializable { if (MapUtils.isEmpty(extInfo)) { return false; } - DimensionType dimensionTYpe = (DimensionType) extInfo.get(DimensionConstants.DIMENSION_TYPE); + Object o = extInfo.get(DimensionConstants.DIMENSION_TYPE); + DimensionType dimensionTYpe = null; + if (o instanceof DimensionType) { + dimensionTYpe = (DimensionType) o; + } + if (o instanceof String) { + dimensionTYpe = DimensionType.valueOf((String) o); + } return DimensionType.isPartitionTime(dimensionTYpe); } + public String getTimeFormat() { + if (MapUtils.isEmpty(extInfo)) { + return null; + } + return (String) extInfo.get(DimensionConstants.DIMENSION_TIME_FORMAT); + } + + public String getPartitionTimeFormat() { + String timeFormat = getTimeFormat(); + if (StringUtils.isNotBlank(timeFormat) && containsPartitionTime()) { + return timeFormat; + } + return null; + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java index 93df8605d..98b167cdf 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java @@ -1,14 +1,17 @@ package com.tencent.supersonic.headless.chat.corrector; +import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeMode; -import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.chat.ChatQueryContext; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; +import java.text.SimpleDateFormat; +import java.util.Date; import java.util.Objects; public class S2SqlDateHelper { @@ -23,7 +26,8 @@ public class S2SqlDateHelper { return defaultDate; } TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); - return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft(); + String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat(); + return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig, partitionTimeFormat).getLeft(); } public static Pair getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId, @@ -40,33 +44,53 @@ public class S2SqlDateHelper { if (QueryType.DETAIL.equals(queryType) && defaultConfig.getUnit() >= 0) { defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); } - return getDefaultDate(defaultDate, defaultConfig); + String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat(); + return getDefaultDate(defaultDate, defaultConfig, partitionTimeFormat); } - private static Pair getDefaultDate(String defaultDate, TimeDefaultConfig defaultConfig) { - if (Objects.isNull(defaultConfig)) { + private static Pair getDefaultDate(String defaultDate, + TimeDefaultConfig defaultConfig, + String partitionTimeFormat) { + if (defaultConfig == null) { return Pair.of(null, null); } Integer unit = defaultConfig.getUnit(); + if (unit == null) { + return Pair.of(defaultDate, defaultDate); + } + + // If the unit is set to less than 0, then do not add relative date. + if (unit < 0) { + return Pair.of(null, null); + } + String period = defaultConfig.getPeriod(); TimeMode timeMode = defaultConfig.getTimeMode(); - if (Objects.nonNull(unit)) { - // If the unit is set to less than 0, then do not add relative date. - if (unit < 0) { - return Pair.of(null, null); - } - DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period); - String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum); - String endDate = DateUtils.getBeforeDate(0, DatePeriodEnum.DAY); - if (unit == 0) { - endDate = startDate; - } - if (TimeMode.LAST.equals(timeMode)) { - endDate = startDate; - } - return Pair.of(startDate, endDate); + DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period); + + String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum); + String endDate = DateUtils.getBeforeDate(0, DatePeriodEnum.DAY); + + if (unit == 0 || TimeMode.LAST.equals(timeMode)) { + endDate = startDate; } - return Pair.of(defaultDate, defaultDate); + if (StringUtils.isNotBlank(partitionTimeFormat)) { + startDate = formatDate(startDate, partitionTimeFormat); + endDate = formatDate(endDate, partitionTimeFormat); + } + return Pair.of(startDate, endDate); } + private static String formatDate(String dateStr, String format) { + try { + // Assuming the input date format is "yyyy-MM-dd" + SimpleDateFormat inputFormat = new SimpleDateFormat(DateUtils.DATE_FORMAT); + Date date = inputFormat.parse(dateStr); + SimpleDateFormat outputFormat = new SimpleDateFormat(format); + return outputFormat.format(date); + } catch (Exception e) { + // Handle the exception, maybe log it and return the original dateStr + return dateStr; + } + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index 444554c84..115c417f2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -11,9 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; -import java.util.HashSet; -import java.util.List; -import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -22,6 +19,10 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + /** * Perform SQL corrections on the time in S2SQL. */ @@ -60,8 +61,11 @@ public class TimeCorrector extends BaseSemanticCorrector { if (isValidDateRange(startEndDate)) { correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); String dateChName = TimeDimensionEnum.DAY.getChName(); + String startDateLeft = startEndDate.getLeft(); + String endDateRight = startEndDate.getRight(); + String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", - dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight()); + dateChName, startDateLeft, dateChName, endDateRight); correctS2SQL = addConditionToSQL(correctS2SQL, condExpr); } } @@ -69,7 +73,7 @@ public class TimeCorrector extends BaseSemanticCorrector { } private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext, - SemanticParseInfo semanticParseInfo) { + SemanticParseInfo semanticParseInfo) { Long dataSetId = semanticParseInfo.getDataSetId(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 58b3f98ae..4a51dc36c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -16,6 +16,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -76,7 +77,7 @@ public class LLMRequestService { llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); llmSchema.setDomainName(dataSetIdToName.get(dataSetId)); - List fieldNameList = getMatchedFieldNames(queryCtx, dataSetId); + Set fieldNameList = getMatchedFieldNames(queryCtx, dataSetId); if (Objects.nonNull(semanticSchema.getDataSetSchemaMap()) && Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) { TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap() @@ -87,14 +88,14 @@ public class LLMRequestService { fieldNameList.add(TimeDimensionEnum.DAY.getChName()); } } - llmSchema.setFieldNameList(fieldNameList); + llmSchema.setFieldNameList(new ArrayList<>(fieldNameList)); llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId)); llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId)); llmSchema.setTerms(getTerms(queryCtx, dataSetId)); llmReq.setSchema(llmSchema); - String priorExts = getPriorExts(queryCtx, fieldNameList); + String priorExts = getPriorExts(queryCtx, new ArrayList<>(fieldNameList)); llmReq.setPriorExts(priorExts); List linking = new ArrayList<>(); @@ -144,27 +145,41 @@ public class LLMRequestService { private String getPriorExts(ChatQueryContext queryContext, List fieldNameList) { StringBuilder extraInfoSb = new StringBuilder(); SemanticSchema semanticSchema = queryContext.getSemanticSchema(); - Map fieldNameToDataFormatType = semanticSchema.getMetrics() - .stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType())) - .flatMap(metricSchemaResp -> { - Set> result = new HashSet<>(); - String dataFormatType = metricSchemaResp.getDataFormatType(); - result.add(Pair.of(metricSchemaResp.getName(), dataFormatType)); - List aliasList = metricSchemaResp.getAlias(); - if (!CollectionUtils.isEmpty(aliasList)) { - for (String alias : aliasList) { - result.add(Pair.of(alias, dataFormatType)); - } - } - return result.stream(); - }).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); + // 获取字段名到数据格式类型的映射 + Map fieldNameToDataFormatType = semanticSchema.getMetrics().stream() + .filter(metric -> Objects.nonNull(metric.getDataFormatType())) + .flatMap(metric -> { + Set> fieldFormatPairs = new HashSet<>(); + String dataFormatType = metric.getDataFormatType(); + fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType)); + List aliasList = metric.getAlias(); + if (!CollectionUtils.isEmpty(aliasList)) { + aliasList.forEach(alias -> fieldFormatPairs.add(Pair.of(alias, dataFormatType))); + } + return fieldFormatPairs.stream(); + }) + .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (existing, replacement) -> existing)); + + Map fieldNameToDateFormat = semanticSchema.getDimensions().stream() + .filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat())) + .collect(Collectors.toMap( + SchemaElement::getName, SchemaElement::getPartitionTimeFormat, (k1, k2) -> k1) + ); + + // 构建额外信息字符串 for (String fieldName : fieldNameList) { String dataFormatType = fieldNameToDataFormatType.get(fieldName); if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType) || DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) { - String format = String.format("%s的计量单位是%s", fieldName, "小数; "); - extraInfoSb.append(format); + extraInfoSb.append(String.format("%s的计量单位是%s; ", fieldName, "小数")); + } + } + // 构建分区日期格式化信息 + for (String fieldName : fieldNameList) { + String timeFormat = fieldNameToDateFormat.get(fieldName); + if (StringUtils.isNotBlank(timeFormat)) { + extraInfoSb.append(String.format("%s的日期Format格式是%s; ", fieldName, timeFormat)); } } return extraInfoSb.toString(); @@ -195,8 +210,8 @@ public class LLMRequestService { protected Map getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) { SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); List elements = semanticSchema.getDimensions(dataSetId); - return elements.stream() - .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); + return elements.stream().collect( + Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } protected List getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) { @@ -230,13 +245,13 @@ public class LLMRequestService { .collect(Collectors.toList()); } - protected List getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) { + protected Set getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { - return new ArrayList<>(); + return new HashSet<>(); } - Set fieldNameList = matchedElements.stream() + return matchedElements.stream() .filter(schemaElementMatch -> { SchemaElementType elementType = schemaElementMatch.getElement().getType(); return SchemaElementType.METRIC.equals(elementType) @@ -252,7 +267,5 @@ public class LLMRequestService { return schemaElementMatch.getWord(); }) .collect(Collectors.toSet()); - - return new ArrayList<>(fieldNameList); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index 73a024212..587fcd1bb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -174,9 +174,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { parseInfo.getDimensionFilters().add(dimensionFilter); } else { QueryFilter dimensionFilter = new QueryFilter(); - List vals = new ArrayList<>(); - entry.getValue().stream().forEach(i -> vals.add(i.getWord())); - dimensionFilter.setValue(vals); + List values = new ArrayList<>(); + entry.getValue().stream().forEach(i -> values.add(i.getWord())); + dimensionFilter.setValue(values); dimensionFilter.setBizName(dimension.getBizName()); dimensionFilter.setName(dimension.getName()); dimensionFilter.setOperator(FilterOperatorEnum.IN); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java index f76ffc055..6c9713520 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java @@ -170,6 +170,7 @@ public class DataSetSchemaBuilder { .type(SchemaElementType.DIMENSION) .build(); dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TYPE, dim.getType()); + if (dim.isTimeDimension()) { String timeFormat = String.valueOf(dim.getExt().get(DimensionConstants.DIMENSION_TIME_FORMAT)); dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, timeFormat);