mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(chat) Make corrections and pass the data date format to the large model. (#1583)
This commit is contained in:
@@ -4,6 +4,7 @@ public class DimensionConstants {
|
||||
|
||||
public static final String DIMENSION_TIME_FORMAT = "time_format";
|
||||
|
||||
|
||||
public static final String DIMENSION_TYPE = "dimension_type";
|
||||
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
@@ -126,4 +127,13 @@ public class DataSetSchema {
|
||||
return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime);
|
||||
}
|
||||
|
||||
public String getPartitionTimeFormat() {
|
||||
for (SchemaElement dimension : dimensions) {
|
||||
String partitionTimeFormat = dimension.getPartitionTimeFormat();
|
||||
if (StringUtils.isNotBlank(partitionTimeFormat)) {
|
||||
return partitionTimeFormat;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,18 @@ package com.tencent.supersonic.headless.api.pojo;
|
||||
import com.google.common.base.Objects;
|
||||
import com.tencent.supersonic.common.pojo.DimensionConstants;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@@ -65,8 +67,29 @@ public class SchemaElement implements Serializable {
|
||||
if (MapUtils.isEmpty(extInfo)) {
|
||||
return false;
|
||||
}
|
||||
DimensionType dimensionTYpe = (DimensionType) extInfo.get(DimensionConstants.DIMENSION_TYPE);
|
||||
Object o = extInfo.get(DimensionConstants.DIMENSION_TYPE);
|
||||
DimensionType dimensionTYpe = null;
|
||||
if (o instanceof DimensionType) {
|
||||
dimensionTYpe = (DimensionType) o;
|
||||
}
|
||||
if (o instanceof String) {
|
||||
dimensionTYpe = DimensionType.valueOf((String) o);
|
||||
}
|
||||
return DimensionType.isPartitionTime(dimensionTYpe);
|
||||
}
|
||||
|
||||
public String getTimeFormat() {
|
||||
if (MapUtils.isEmpty(extInfo)) {
|
||||
return null;
|
||||
}
|
||||
return (String) extInfo.get(DimensionConstants.DIMENSION_TIME_FORMAT);
|
||||
}
|
||||
|
||||
public String getPartitionTimeFormat() {
|
||||
String timeFormat = getTimeFormat();
|
||||
if (StringUtils.isNotBlank(timeFormat) && containsPartitionTime()) {
|
||||
return timeFormat;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.Date;
|
||||
import java.util.Objects;
|
||||
|
||||
public class S2SqlDateHelper {
|
||||
@@ -23,7 +26,8 @@ public class S2SqlDateHelper {
|
||||
return defaultDate;
|
||||
}
|
||||
TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
|
||||
String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat();
|
||||
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig, partitionTimeFormat).getLeft();
|
||||
}
|
||||
|
||||
public static Pair<String, String> getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId,
|
||||
@@ -40,33 +44,53 @@ public class S2SqlDateHelper {
|
||||
if (QueryType.DETAIL.equals(queryType) && defaultConfig.getUnit() >= 0) {
|
||||
defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||
}
|
||||
return getDefaultDate(defaultDate, defaultConfig);
|
||||
String partitionTimeFormat = dataSetSchema.getPartitionTimeFormat();
|
||||
return getDefaultDate(defaultDate, defaultConfig, partitionTimeFormat);
|
||||
}
|
||||
|
||||
private static Pair<String, String> getDefaultDate(String defaultDate, TimeDefaultConfig defaultConfig) {
|
||||
if (Objects.isNull(defaultConfig)) {
|
||||
private static Pair<String, String> getDefaultDate(String defaultDate,
|
||||
TimeDefaultConfig defaultConfig,
|
||||
String partitionTimeFormat) {
|
||||
if (defaultConfig == null) {
|
||||
return Pair.of(null, null);
|
||||
}
|
||||
Integer unit = defaultConfig.getUnit();
|
||||
String period = defaultConfig.getPeriod();
|
||||
TimeMode timeMode = defaultConfig.getTimeMode();
|
||||
if (Objects.nonNull(unit)) {
|
||||
if (unit == null) {
|
||||
return Pair.of(defaultDate, defaultDate);
|
||||
}
|
||||
|
||||
// If the unit is set to less than 0, then do not add relative date.
|
||||
if (unit < 0) {
|
||||
return Pair.of(null, null);
|
||||
}
|
||||
|
||||
String period = defaultConfig.getPeriod();
|
||||
TimeMode timeMode = defaultConfig.getTimeMode();
|
||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
|
||||
|
||||
String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum);
|
||||
String endDate = DateUtils.getBeforeDate(0, DatePeriodEnum.DAY);
|
||||
if (unit == 0) {
|
||||
|
||||
if (unit == 0 || TimeMode.LAST.equals(timeMode)) {
|
||||
endDate = startDate;
|
||||
}
|
||||
if (TimeMode.LAST.equals(timeMode)) {
|
||||
endDate = startDate;
|
||||
if (StringUtils.isNotBlank(partitionTimeFormat)) {
|
||||
startDate = formatDate(startDate, partitionTimeFormat);
|
||||
endDate = formatDate(endDate, partitionTimeFormat);
|
||||
}
|
||||
return Pair.of(startDate, endDate);
|
||||
}
|
||||
return Pair.of(defaultDate, defaultDate);
|
||||
}
|
||||
|
||||
private static String formatDate(String dateStr, String format) {
|
||||
try {
|
||||
// Assuming the input date format is "yyyy-MM-dd"
|
||||
SimpleDateFormat inputFormat = new SimpleDateFormat(DateUtils.DATE_FORMAT);
|
||||
Date date = inputFormat.parse(dateStr);
|
||||
SimpleDateFormat outputFormat = new SimpleDateFormat(format);
|
||||
return outputFormat.format(date);
|
||||
} catch (Exception e) {
|
||||
// Handle the exception, maybe log it and return the original dateStr
|
||||
return dateStr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,9 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -22,6 +19,10 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the time in S2SQL.
|
||||
*/
|
||||
@@ -60,8 +61,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
if (isValidDateRange(startEndDate)) {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
String dateChName = TimeDimensionEnum.DAY.getChName();
|
||||
String startDateLeft = startEndDate.getLeft();
|
||||
String endDateRight = startEndDate.getRight();
|
||||
|
||||
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
|
||||
dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight());
|
||||
dateChName, startDateLeft, dateChName, endDateRight);
|
||||
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -76,7 +77,7 @@ public class LLMRequestService {
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||
|
||||
List<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
|
||||
if (Objects.nonNull(semanticSchema.getDataSetSchemaMap())
|
||||
&& Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) {
|
||||
TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap()
|
||||
@@ -87,14 +88,14 @@ public class LLMRequestService {
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
}
|
||||
}
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmSchema.setFieldNameList(new ArrayList<>(fieldNameList));
|
||||
|
||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
String priorExts = getPriorExts(queryCtx, fieldNameList);
|
||||
String priorExts = getPriorExts(queryCtx, new ArrayList<>(fieldNameList));
|
||||
llmReq.setPriorExts(priorExts);
|
||||
|
||||
List<LLMReq.ElementValue> linking = new ArrayList<>();
|
||||
@@ -144,27 +145,41 @@ public class LLMRequestService {
|
||||
private String getPriorExts(ChatQueryContext queryContext, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics()
|
||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||
.flatMap(metricSchemaResp -> {
|
||||
Set<Pair<String, String>> result = new HashSet<>();
|
||||
String dataFormatType = metricSchemaResp.getDataFormatType();
|
||||
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
|
||||
List<String> aliasList = metricSchemaResp.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, dataFormatType));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
|
||||
|
||||
// 获取字段名到数据格式类型的映射
|
||||
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics().stream()
|
||||
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
|
||||
.flatMap(metric -> {
|
||||
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
|
||||
String dataFormatType = metric.getDataFormatType();
|
||||
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
|
||||
List<String> aliasList = metric.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
aliasList.forEach(alias -> fieldFormatPairs.add(Pair.of(alias, dataFormatType)));
|
||||
}
|
||||
return fieldFormatPairs.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (existing, replacement) -> existing));
|
||||
|
||||
Map<String, String> fieldNameToDateFormat = semanticSchema.getDimensions().stream()
|
||||
.filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat()))
|
||||
.collect(Collectors.toMap(
|
||||
SchemaElement::getName, SchemaElement::getPartitionTimeFormat, (k1, k2) -> k1)
|
||||
);
|
||||
|
||||
// 构建额外信息字符串
|
||||
for (String fieldName : fieldNameList) {
|
||||
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
|
||||
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
||||
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
||||
String format = String.format("%s的计量单位是%s", fieldName, "小数; ");
|
||||
extraInfoSb.append(format);
|
||||
extraInfoSb.append(String.format("%s的计量单位是%s; ", fieldName, "小数"));
|
||||
}
|
||||
}
|
||||
// 构建分区日期格式化信息
|
||||
for (String fieldName : fieldNameList) {
|
||||
String timeFormat = fieldNameToDateFormat.get(fieldName);
|
||||
if (StringUtils.isNotBlank(timeFormat)) {
|
||||
extraInfoSb.append(String.format("%s的日期Format格式是%s; ", fieldName, timeFormat));
|
||||
}
|
||||
}
|
||||
return extraInfoSb.toString();
|
||||
@@ -195,8 +210,8 @@ public class LLMRequestService {
|
||||
protected Map<Long, String> getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
|
||||
return elements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
return elements.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
@@ -230,13 +245,13 @@ public class LLMRequestService {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
protected List<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected Set<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
return new HashSet<>();
|
||||
}
|
||||
Set<String> fieldNameList = matchedElements.stream()
|
||||
return matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType)
|
||||
@@ -252,7 +267,5 @@ public class LLMRequestService {
|
||||
return schemaElementMatch.getWord();
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
return new ArrayList<>(fieldNameList);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,9 +174,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
parseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
} else {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
List<String> vals = new ArrayList<>();
|
||||
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
|
||||
dimensionFilter.setValue(vals);
|
||||
List<String> values = new ArrayList<>();
|
||||
entry.getValue().stream().forEach(i -> values.add(i.getWord()));
|
||||
dimensionFilter.setValue(values);
|
||||
dimensionFilter.setBizName(dimension.getBizName());
|
||||
dimensionFilter.setName(dimension.getName());
|
||||
dimensionFilter.setOperator(FilterOperatorEnum.IN);
|
||||
|
||||
@@ -170,6 +170,7 @@ public class DataSetSchemaBuilder {
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.build();
|
||||
dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TYPE, dim.getType());
|
||||
|
||||
if (dim.isTimeDimension()) {
|
||||
String timeFormat = String.valueOf(dim.getExt().get(DimensionConstants.DIMENSION_TIME_FORMAT));
|
||||
dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, timeFormat);
|
||||
|
||||
Reference in New Issue
Block a user