(improvement)(chat) Correct query dates using the partition time format and pass that date format to the large model. (#1583)

This commit is contained in:
lexluo09
2024-08-19 01:21:10 +08:00
committed by GitHub
parent 10a5e485cb
commit ba55ecb31e
8 changed files with 136 additions and 60 deletions

View File

@@ -4,6 +4,7 @@ public class DimensionConstants {
public static final String DIMENSION_TIME_FORMAT = "time_format"; public static final String DIMENSION_TIME_FORMAT = "time_format";
public static final String DIMENSION_TYPE = "dimension_type"; public static final String DIMENSION_TYPE = "dimension_type";
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo;
import lombok.Data; import lombok.Data;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
@@ -126,4 +127,13 @@ public class DataSetSchema {
return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime); return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime);
} }
/**
 * Finds the partition time format declared by this data set's dimensions.
 *
 * <p>Dimensions are inspected lazily in iteration order; the first one whose
 * {@code getPartitionTimeFormat()} yields a non-blank value wins.
 *
 * @return the first non-blank partition time format, or {@code null} when
 *         no dimension provides one
 */
public String getPartitionTimeFormat() {
    return dimensions.stream()
            .map(SchemaElement::getPartitionTimeFormat)
            // Blank and null formats are treated as "not configured".
            .filter(StringUtils::isNotBlank)
            .findFirst()
            .orElse(null);
}
} }

View File

@@ -3,16 +3,18 @@ package com.tencent.supersonic.headless.api.pojo;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.DimensionConstants; import com.tencent.supersonic.common.pojo.DimensionConstants;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data @Data
@Getter @Getter
@@ -65,8 +67,29 @@ public class SchemaElement implements Serializable {
if (MapUtils.isEmpty(extInfo)) { if (MapUtils.isEmpty(extInfo)) {
return false; return false;
} }
DimensionType dimensionTYpe = (DimensionType) extInfo.get(DimensionConstants.DIMENSION_TYPE); Object o = extInfo.get(DimensionConstants.DIMENSION_TYPE);
DimensionType dimensionTYpe = null;
if (o instanceof DimensionType) {
dimensionTYpe = (DimensionType) o;
}
if (o instanceof String) {
dimensionTYpe = DimensionType.valueOf((String) o);
}
return DimensionType.isPartitionTime(dimensionTYpe); return DimensionType.isPartitionTime(dimensionTYpe);
} }
/**
 * Reads the entry stored in {@code extInfo} under
 * {@link DimensionConstants#DIMENSION_TIME_FORMAT}.
 *
 * @return the stored time format, or {@code null} when {@code extInfo} is
 *         empty (as judged by {@code MapUtils.isEmpty}) or holds no such entry
 */
public String getTimeFormat() {
    return MapUtils.isEmpty(extInfo)
            ? null
            : (String) extInfo.get(DimensionConstants.DIMENSION_TIME_FORMAT);
}
/**
 * Returns this element's time format, but only when the element is a
 * partition-time dimension.
 *
 * @return the value of {@link #getTimeFormat()} when it is non-blank and
 *         {@code containsPartitionTime()} is true; otherwise {@code null}
 */
public String getPartitionTimeFormat() {
    String format = getTimeFormat();
    // Check the format first so containsPartitionTime() is only consulted
    // when there is actually a format to return.
    if (StringUtils.isBlank(format)) {
        return null;
    }
    return containsPartitionTime() ? format : null;
}
} }

View File

@@ -1,14 +1,17 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Objects; import java.util.Objects;
public class S2SqlDateHelper { public class S2SqlDateHelper {
@@ -23,7 +26,8 @@ public class S2SqlDateHelper {
return defaultDate; return defaultDate;
} }
TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft(); String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat();
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig, partitionTimeFormat).getLeft();
} }
public static Pair<String, String> getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId, public static Pair<String, String> getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId,
@@ -40,33 +44,53 @@ public class S2SqlDateHelper {
if (QueryType.DETAIL.equals(queryType) && defaultConfig.getUnit() >= 0) { if (QueryType.DETAIL.equals(queryType) && defaultConfig.getUnit() >= 0) {
defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
} }
return getDefaultDate(defaultDate, defaultConfig); String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat();
return getDefaultDate(defaultDate, defaultConfig, partitionTimeFormat);
} }
private static Pair<String, String> getDefaultDate(String defaultDate, TimeDefaultConfig defaultConfig) { private static Pair<String, String> getDefaultDate(String defaultDate,
if (Objects.isNull(defaultConfig)) { TimeDefaultConfig defaultConfig,
String partitionTimeFormat) {
if (defaultConfig == null) {
return Pair.of(null, null); return Pair.of(null, null);
} }
Integer unit = defaultConfig.getUnit(); Integer unit = defaultConfig.getUnit();
if (unit == null) {
return Pair.of(defaultDate, defaultDate);
}
// If the unit is set to less than 0, then do not add relative date.
if (unit < 0) {
return Pair.of(null, null);
}
String period = defaultConfig.getPeriod(); String period = defaultConfig.getPeriod();
TimeMode timeMode = defaultConfig.getTimeMode(); TimeMode timeMode = defaultConfig.getTimeMode();
if (Objects.nonNull(unit)) { DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
// If the unit is set to less than 0, then do not add relative date.
if (unit < 0) { String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum);
return Pair.of(null, null); String endDate = DateUtils.getBeforeDate(0, DatePeriodEnum.DAY);
}
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period); if (unit == 0 || TimeMode.LAST.equals(timeMode)) {
String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum); endDate = startDate;
String endDate = DateUtils.getBeforeDate(0, DatePeriodEnum.DAY);
if (unit == 0) {
endDate = startDate;
}
if (TimeMode.LAST.equals(timeMode)) {
endDate = startDate;
}
return Pair.of(startDate, endDate);
} }
return Pair.of(defaultDate, defaultDate); if (StringUtils.isNotBlank(partitionTimeFormat)) {
startDate = formatDate(startDate, partitionTimeFormat);
endDate = formatDate(endDate, partitionTimeFormat);
}
return Pair.of(startDate, endDate);
} }
/**
 * Re-renders a date string in a different pattern.
 *
 * <p>The input is parsed with {@code DateUtils.DATE_FORMAT} (presumably
 * "yyyy-MM-dd" - confirm against DateUtils) and then formatted with the
 * requested pattern. This is best effort: if parsing or formatting fails for
 * any reason, the input string is handed back untouched.
 *
 * @param dateStr the date text to convert
 * @param format  the target {@link SimpleDateFormat} pattern
 * @return the converted date, or {@code dateStr} itself when conversion fails
 */
private static String formatDate(String dateStr, String format) {
    String formatted = dateStr;
    try {
        Date parsed = new SimpleDateFormat(DateUtils.DATE_FORMAT).parse(dateStr);
        formatted = new SimpleDateFormat(format).format(parsed);
    } catch (Exception ignored) {
        // Deliberately ignored: callers receive the original string as a fallback.
    }
    return formatted;
}
} }

View File

@@ -11,9 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
@@ -22,6 +19,10 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/** /**
* Perform SQL corrections on the time in S2SQL. * Perform SQL corrections on the time in S2SQL.
*/ */
@@ -60,8 +61,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
if (isValidDateRange(startEndDate)) { if (isValidDateRange(startEndDate)) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName(); String dateChName = TimeDimensionEnum.DAY.getChName();
String startDateLeft = startEndDate.getLeft();
String endDateRight = startEndDate.getRight();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight()); dateChName, startDateLeft, dateChName, endDateRight);
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr); correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
} }
} }
@@ -69,7 +73,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
} }
private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext, private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) { SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId(); Long dataSetId = semanticParseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);

View File

@@ -16,6 +16,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -76,7 +77,7 @@ public class LLMRequestService {
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setDomainName(dataSetIdToName.get(dataSetId)); llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
List<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId); Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
if (Objects.nonNull(semanticSchema.getDataSetSchemaMap()) if (Objects.nonNull(semanticSchema.getDataSetSchemaMap())
&& Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) { && Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) {
TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap() TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap()
@@ -87,14 +88,14 @@ public class LLMRequestService {
fieldNameList.add(TimeDimensionEnum.DAY.getChName()); fieldNameList.add(TimeDimensionEnum.DAY.getChName());
} }
} }
llmSchema.setFieldNameList(fieldNameList); llmSchema.setFieldNameList(new ArrayList<>(fieldNameList));
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId)); llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId)); llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
llmSchema.setTerms(getTerms(queryCtx, dataSetId)); llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema); llmReq.setSchema(llmSchema);
String priorExts = getPriorExts(queryCtx, fieldNameList); String priorExts = getPriorExts(queryCtx, new ArrayList<>(fieldNameList));
llmReq.setPriorExts(priorExts); llmReq.setPriorExts(priorExts);
List<LLMReq.ElementValue> linking = new ArrayList<>(); List<LLMReq.ElementValue> linking = new ArrayList<>();
@@ -144,27 +145,41 @@ public class LLMRequestService {
private String getPriorExts(ChatQueryContext queryContext, List<String> fieldNameList) { private String getPriorExts(ChatQueryContext queryContext, List<String> fieldNameList) {
StringBuilder extraInfoSb = new StringBuilder(); StringBuilder extraInfoSb = new StringBuilder();
SemanticSchema semanticSchema = queryContext.getSemanticSchema(); SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics()
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
.flatMap(metricSchemaResp -> {
Set<Pair<String, String>> result = new HashSet<>();
String dataFormatType = metricSchemaResp.getDataFormatType();
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
List<String> aliasList = metricSchemaResp.getAlias();
if (!CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, dataFormatType));
}
}
return result.stream();
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
// 获取字段名到数据格式类型的映射
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics().stream()
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
.flatMap(metric -> {
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
String dataFormatType = metric.getDataFormatType();
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
List<String> aliasList = metric.getAlias();
if (!CollectionUtils.isEmpty(aliasList)) {
aliasList.forEach(alias -> fieldFormatPairs.add(Pair.of(alias, dataFormatType)));
}
return fieldFormatPairs.stream();
})
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (existing, replacement) -> existing));
Map<String, String> fieldNameToDateFormat = semanticSchema.getDimensions().stream()
.filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat()))
.collect(Collectors.toMap(
SchemaElement::getName, SchemaElement::getPartitionTimeFormat, (k1, k2) -> k1)
);
// 构建额外信息字符串
for (String fieldName : fieldNameList) { for (String fieldName : fieldNameList) {
String dataFormatType = fieldNameToDataFormatType.get(fieldName); String dataFormatType = fieldNameToDataFormatType.get(fieldName);
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType) if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) { || DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
String format = String.format("%s的计量单位是%s", fieldName, "小数; "); extraInfoSb.append(String.format("%s的计量单位是%s; ", fieldName, "小数"));
extraInfoSb.append(format); }
}
// 构建分区日期格式化信息
for (String fieldName : fieldNameList) {
String timeFormat = fieldNameToDateFormat.get(fieldName);
if (StringUtils.isNotBlank(timeFormat)) {
extraInfoSb.append(String.format("%s的日期Format格式是%s; ", fieldName, timeFormat));
} }
} }
return extraInfoSb.toString(); return extraInfoSb.toString();
@@ -195,8 +210,8 @@ public class LLMRequestService {
protected Map<Long, String> getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) { protected Map<Long, String> getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId); List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
return elements.stream() return elements.stream().collect(
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
} }
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) { protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
@@ -230,13 +245,13 @@ public class LLMRequestService {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
protected List<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) { protected Set<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId); Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) { if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>(); return new HashSet<>();
} }
Set<String> fieldNameList = matchedElements.stream() return matchedElements.stream()
.filter(schemaElementMatch -> { .filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType(); SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) return SchemaElementType.METRIC.equals(elementType)
@@ -252,7 +267,5 @@ public class LLMRequestService {
return schemaElementMatch.getWord(); return schemaElementMatch.getWord();
}) })
.collect(Collectors.toSet()); .collect(Collectors.toSet());
return new ArrayList<>(fieldNameList);
} }
} }

View File

@@ -174,9 +174,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
parseInfo.getDimensionFilters().add(dimensionFilter); parseInfo.getDimensionFilters().add(dimensionFilter);
} else { } else {
QueryFilter dimensionFilter = new QueryFilter(); QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>(); List<String> values = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord())); entry.getValue().stream().forEach(i -> values.add(i.getWord()));
dimensionFilter.setValue(vals); dimensionFilter.setValue(values);
dimensionFilter.setBizName(dimension.getBizName()); dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName()); dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN); dimensionFilter.setOperator(FilterOperatorEnum.IN);

View File

@@ -170,6 +170,7 @@ public class DataSetSchemaBuilder {
.type(SchemaElementType.DIMENSION) .type(SchemaElementType.DIMENSION)
.build(); .build();
dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TYPE, dim.getType()); dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TYPE, dim.getType());
if (dim.isTimeDimension()) { if (dim.isTimeDimension()) {
String timeFormat = String.valueOf(dim.getExt().get(DimensionConstants.DIMENSION_TIME_FORMAT)); String timeFormat = String.valueOf(dim.getExt().get(DimensionConstants.DIMENSION_TIME_FORMAT));
dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, timeFormat); dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, timeFormat);