(improvement)(chat) Make corrections and pass the data date format to the large model. (#1583)

This commit is contained in:
lexluo09
2024-08-19 01:21:10 +08:00
committed by GitHub
parent 10a5e485cb
commit ba55ecb31e
8 changed files with 136 additions and 60 deletions

View File

@@ -4,6 +4,7 @@ public class DimensionConstants {
public static final String DIMENSION_TIME_FORMAT = "time_format";
public static final String DIMENSION_TYPE = "dimension_type";
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.HashSet;
@@ -126,4 +127,13 @@ public class DataSetSchema {
return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime);
}
/**
 * Looks up the partition time format declared by this data set's dimensions.
 *
 * <p>Dimensions are visited in the iteration order of {@code dimensions}; the first
 * dimension whose {@code getPartitionTimeFormat()} yields a non-blank value wins.
 *
 * @return the first non-blank partition time format, or {@code null} when no
 *         dimension provides one
 */
public String getPartitionTimeFormat() {
    return dimensions.stream()
            .map(SchemaElement::getPartitionTimeFormat)
            // Blank (null, empty or whitespace-only) formats are treated as "not declared".
            .filter(StringUtils::isNotBlank)
            .findFirst()
            .orElse(null);
}
}

View File

@@ -3,16 +3,18 @@ package com.tencent.supersonic.headless.api.pojo;
import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.DimensionConstants;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data
@Getter
@@ -65,8 +67,29 @@ public class SchemaElement implements Serializable {
if (MapUtils.isEmpty(extInfo)) {
return false;
}
DimensionType dimensionTYpe = (DimensionType) extInfo.get(DimensionConstants.DIMENSION_TYPE);
Object o = extInfo.get(DimensionConstants.DIMENSION_TYPE);
DimensionType dimensionTYpe = null;
if (o instanceof DimensionType) {
dimensionTYpe = (DimensionType) o;
}
if (o instanceof String) {
dimensionTYpe = DimensionType.valueOf((String) o);
}
return DimensionType.isPartitionTime(dimensionTYpe);
}
/**
 * Reads the time format stored in {@code extInfo} under
 * {@link DimensionConstants#DIMENSION_TIME_FORMAT}.
 *
 * @return the stored value cast to {@code String}, or {@code null} when
 *         {@code extInfo} is empty or holds no entry for that key
 */
public String getTimeFormat() {
    return MapUtils.isEmpty(extInfo)
            ? null
            : (String) extInfo.get(DimensionConstants.DIMENSION_TIME_FORMAT);
}
/**
 * Returns this element's time format, but only when the element is a partition-time
 * dimension according to {@code containsPartitionTime()}.
 *
 * @return the non-blank time format of a partition-time element, otherwise {@code null}
 */
public String getPartitionTimeFormat() {
    final String format = getTimeFormat();
    // Check the format first so containsPartitionTime() is only consulted when needed.
    if (StringUtils.isBlank(format)) {
        return null;
    }
    return containsPartitionTime() ? format : null;
}
}

View File

@@ -1,14 +1,17 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Objects;
public class S2SqlDateHelper {
@@ -23,7 +26,8 @@ public class S2SqlDateHelper {
return defaultDate;
}
TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat();
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig, partitionTimeFormat).getLeft();
}
public static Pair<String, String> getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId,
@@ -40,33 +44,53 @@ public class S2SqlDateHelper {
if (QueryType.DETAIL.equals(queryType) && defaultConfig.getUnit() >= 0) {
defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
}
return getDefaultDate(defaultDate, defaultConfig);
String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat();
return getDefaultDate(defaultDate, defaultConfig, partitionTimeFormat);
}
private static Pair<String, String> getDefaultDate(String defaultDate, TimeDefaultConfig defaultConfig) {
if (Objects.isNull(defaultConfig)) {
private static Pair<String, String> getDefaultDate(String defaultDate,
TimeDefaultConfig defaultConfig,
String partitionTimeFormat) {
if (defaultConfig == null) {
return Pair.of(null, null);
}
Integer unit = defaultConfig.getUnit();
String period = defaultConfig.getPeriod();
TimeMode timeMode = defaultConfig.getTimeMode();
if (Objects.nonNull(unit)) {
if (unit == null) {
return Pair.of(defaultDate, defaultDate);
}
// If the unit is set to less than 0, then do not add relative date.
if (unit < 0) {
return Pair.of(null, null);
}
String period = defaultConfig.getPeriod();
TimeMode timeMode = defaultConfig.getTimeMode();
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum);
String endDate = DateUtils.getBeforeDate(0, DatePeriodEnum.DAY);
if (unit == 0) {
if (unit == 0 || TimeMode.LAST.equals(timeMode)) {
endDate = startDate;
}
if (TimeMode.LAST.equals(timeMode)) {
endDate = startDate;
if (StringUtils.isNotBlank(partitionTimeFormat)) {
startDate = formatDate(startDate, partitionTimeFormat);
endDate = formatDate(endDate, partitionTimeFormat);
}
return Pair.of(startDate, endDate);
}
return Pair.of(defaultDate, defaultDate);
}
/**
 * Re-renders a date string using the given output pattern.
 *
 * <p>The input is parsed with {@code DateUtils.DATE_FORMAT} and then formatted with
 * {@code format}. This is a best-effort conversion: if parsing or formatting fails for
 * any reason, the input string is returned unchanged.
 *
 * @param dateStr the date string to convert
 * @param format  the {@link SimpleDateFormat} pattern to render the date with
 * @return the reformatted date, or {@code dateStr} itself when the conversion fails
 */
private static String formatDate(String dateStr, String format) {
    String result = dateStr;
    try {
        Date parsed = new SimpleDateFormat(DateUtils.DATE_FORMAT).parse(dateStr);
        result = new SimpleDateFormat(format).format(parsed);
    } catch (Exception ignored) {
        // Deliberate fallback: keep the original date string when it cannot be converted.
    }
    return result;
}
}

View File

@@ -11,9 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
@@ -22,6 +19,10 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
* Perform SQL corrections on the time in S2SQL.
*/
@@ -60,8 +61,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
if (isValidDateRange(startEndDate)) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName();
String startDateLeft = startEndDate.getLeft();
String endDateRight = startEndDate.getRight();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight());
dateChName, startDateLeft, dateChName, endDateRight);
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
}
}

View File

@@ -16,6 +16,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -76,7 +77,7 @@ public class LLMRequestService {
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
List<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
if (Objects.nonNull(semanticSchema.getDataSetSchemaMap())
&& Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) {
TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap()
@@ -87,14 +88,14 @@ public class LLMRequestService {
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
}
}
llmSchema.setFieldNameList(fieldNameList);
llmSchema.setFieldNameList(new ArrayList<>(fieldNameList));
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);
String priorExts = getPriorExts(queryCtx, fieldNameList);
String priorExts = getPriorExts(queryCtx, new ArrayList<>(fieldNameList));
llmReq.setPriorExts(priorExts);
List<LLMReq.ElementValue> linking = new ArrayList<>();
@@ -144,27 +145,41 @@ public class LLMRequestService {
private String getPriorExts(ChatQueryContext queryContext, List<String> fieldNameList) {
StringBuilder extraInfoSb = new StringBuilder();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics()
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
.flatMap(metricSchemaResp -> {
Set<Pair<String, String>> result = new HashSet<>();
String dataFormatType = metricSchemaResp.getDataFormatType();
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
List<String> aliasList = metricSchemaResp.getAlias();
if (!CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, dataFormatType));
}
}
return result.stream();
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
// Build the mapping from field name to data format type
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics().stream()
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
.flatMap(metric -> {
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
String dataFormatType = metric.getDataFormatType();
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
List<String> aliasList = metric.getAlias();
if (!CollectionUtils.isEmpty(aliasList)) {
aliasList.forEach(alias -> fieldFormatPairs.add(Pair.of(alias, dataFormatType)));
}
return fieldFormatPairs.stream();
})
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (existing, replacement) -> existing));
Map<String, String> fieldNameToDateFormat = semanticSchema.getDimensions().stream()
.filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat()))
.collect(Collectors.toMap(
SchemaElement::getName, SchemaElement::getPartitionTimeFormat, (k1, k2) -> k1)
);
// Build the extra-info string
for (String fieldName : fieldNameList) {
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
String format = String.format("%s的计量单位是%s", fieldName, "小数; ");
extraInfoSb.append(format);
extraInfoSb.append(String.format("%s的计量单位是%s; ", fieldName, "小数"));
}
}
// Append the partition date format information
for (String fieldName : fieldNameList) {
String timeFormat = fieldNameToDateFormat.get(fieldName);
if (StringUtils.isNotBlank(timeFormat)) {
extraInfoSb.append(String.format("%s的日期Format格式是%s; ", fieldName, timeFormat));
}
}
return extraInfoSb.toString();
@@ -195,8 +210,8 @@ public class LLMRequestService {
protected Map<Long, String> getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
return elements.stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
return elements.stream().collect(
Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
@@ -230,13 +245,13 @@ public class LLMRequestService {
.collect(Collectors.toList());
}
protected List<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
protected Set<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
return new HashSet<>();
}
Set<String> fieldNameList = matchedElements.stream()
return matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType)
@@ -252,7 +267,5 @@ public class LLMRequestService {
return schemaElementMatch.getWord();
})
.collect(Collectors.toSet());
return new ArrayList<>(fieldNameList);
}
}

View File

@@ -174,9 +174,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
parseInfo.getDimensionFilters().add(dimensionFilter);
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
List<String> values = new ArrayList<>();
entry.getValue().stream().forEach(i -> values.add(i.getWord()));
dimensionFilter.setValue(values);
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);

View File

@@ -170,6 +170,7 @@ public class DataSetSchemaBuilder {
.type(SchemaElementType.DIMENSION)
.build();
dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TYPE, dim.getType());
if (dim.isTimeDimension()) {
String timeFormat = String.valueOf(dim.getExt().get(DimensionConstants.DIMENSION_TIME_FORMAT));
dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, timeFormat);