[improvement][headless]Add metricFormat field to metric metadata of the Text2SQL prompt. #1621

This commit is contained in:
jerryjzhang
2024-09-12 18:13:44 +08:00
parent 2fa3bfe019
commit 693356e46a
3 changed files with 46 additions and 73 deletions

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -14,8 +13,8 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -26,7 +25,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -69,7 +67,6 @@ public class LLMRequestService {
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setDataSetId(dataSetId);
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
@@ -77,9 +74,6 @@ public class LLMRequestService {
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);
String priorKnowledge = getPriorKnowledge(queryCtx, llmSchema);
llmReq.setPriorExts(priorKnowledge);
List<LLMReq.ElementValue> linking = new ArrayList<>();
boolean linkingValueEnabled =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
@@ -132,14 +126,6 @@ public class LLMRequestService {
.collect(Collectors.toList());
}
/**
 * Assembles the prior-knowledge text for an LLM request.
 *
 * <p>The text currently consists solely of whatever {@code appendMetricPriorKnowledge} writes
 * for the metrics of {@code llmSchema}.
 *
 * @param queryContext query context supplying the semantic schema
 * @param llmSchema schema whose metrics are described
 * @return the accumulated prior-knowledge text (empty when nothing was appended)
 */
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
    StringBuilder knowledge = new StringBuilder();
    appendMetricPriorKnowledge(llmSchema, knowledge, queryContext.getSemanticSchema());
    return knowledge.toString();
}
private Map<String, String> getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) {
return semanticSchema.getMetrics().stream()
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
@@ -164,34 +150,7 @@ public class LLMRequestService {
(existing, replacement) -> existing));
}
/**
 * Appends one hint sentence per decimal-valued metric to the prior-knowledge buffer.
 *
 * <p>Each metric of {@code llmSchema} is looked up by name in the name-to-data-format map built
 * from {@code semanticSchema}. Metrics whose format type equals (ignoring case) the name of
 * {@code DataFormatTypeEnum.DECIMAL} or {@code DataFormatTypeEnum.PERCENT} get a sentence saying
 * that the metric's unit of measure is a decimal; all other metrics are ignored.
 *
 * @param llmSchema schema whose metrics are inspected
 * @param priorKnowledgeBuilder buffer that receives the generated sentences
 * @param semanticSchema semantic schema used to resolve each metric's data format type
 */
private void appendMetricPriorKnowledge(
        LLMReq.LLMSchema llmSchema,
        StringBuilder priorKnowledgeBuilder,
        SemanticSchema semanticSchema) {
    Map<String, String> formatTypeByMetric = getFieldNameToDataFormatTypeMap(semanticSchema);
    for (SchemaElement metric : llmSchema.getMetrics()) {
        String metricName = metric.getName();
        // A metric without a registered format type yields null, which never matches below.
        String formatType = formatTypeByMetric.get(metricName);
        boolean decimalValued =
                DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(formatType)
                        || DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(formatType);
        if (!decimalValued) {
            continue;
        }
        priorKnowledgeBuilder.append(String.format("%s的计量单位是%s; ", metricName, "小数"));
    }
}
/**
 * Maps dimension names to their time formats.
 *
 * <p>Only dimensions with a non-blank time format are included. When several dimensions share a
 * name, the first one encountered wins.
 *
 * @param semanticSchema semantic schema whose dimensions are scanned
 * @return map from dimension name to time format
 */
private Map<String, String> getFieldNameToDateFormatMap(SemanticSchema semanticSchema) {
    return semanticSchema.getDimensions().stream()
            .filter(dim -> !StringUtils.isBlank(dim.getTimeFormat()))
            .collect(
                    Collectors.toMap(
                            dim -> dim.getName(),
                            // Null-safe read: fall back to an empty string.
                            dim -> Objects.toString(dim.getTimeFormat(), ""),
                            (first, duplicate) -> first));
}
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
public List<LLMReq.ElementValue> getValues(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
@@ -218,7 +177,8 @@ public class LLMRequestService {
return new ArrayList<>(valueMatches);
}
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
protected List<SchemaElement> getMatchedMetrics(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
@@ -240,27 +200,8 @@ public class LLMRequestService {
return schemaElements;
}
/**
 * Looks up the partition dimension of the given data set.
 *
 * @param queryCtx query context supplying the semantic schema
 * @param dataSetId id of the data set to look up
 * @return the data set's partition dimension, or {@code null} when the semantic schema, its
 *     data-set map, or the map entry for {@code dataSetId} is missing
 */
protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) {
    SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
    if (semanticSchema == null) {
        return null;
    }
    Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
    if (dataSetSchemaMap == null) {
        return null;
    }
    DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
    // Map.get returns null for an unknown id; answer "no partition dimension" rather than
    // failing with a NullPointerException.
    return dataSetSchema == null ? null : dataSetSchema.getPartitionDimension();
}
/**
 * Looks up the primary key element of the given data set.
 *
 * @param queryCtx query context supplying the semantic schema
 * @param dataSetId id of the data set to look up
 * @return the data set's primary key, or {@code null} when the semantic schema, its data-set
 *     map, or the map entry for {@code dataSetId} is missing
 */
protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) {
    SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
    if (semanticSchema == null) {
        return null;
    }
    Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
    if (dataSetSchemaMap == null) {
        return null;
    }
    DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
    // Map.get returns null for an unknown id; answer "no primary key" rather than failing
    // with a NullPointerException.
    return dataSetSchema == null ? null : dataSetSchema.getPrimaryKey();
}
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
protected List<SchemaElement> getMatchedDimensions(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
@@ -275,4 +216,24 @@ public class LLMRequestService {
return new ArrayList<>(dimensionElements);
}
/**
 * Looks up the partition dimension of the given data set.
 *
 * @param queryCtx query context supplying the semantic schema; must not be {@code null}
 * @param dataSetId id of the data set to look up
 * @return the data set's partition dimension, or {@code null} when the semantic schema, its
 *     data-set map, or the map entry for {@code dataSetId} is missing
 */
protected SchemaElement getPartitionTime(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
    SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
    if (semanticSchema == null) {
        return null;
    }
    Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
    if (dataSetSchemaMap == null) {
        return null;
    }
    DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
    // Map.get returns null for an unknown id; answer "no partition dimension" rather than
    // failing with a NullPointerException.
    return dataSetSchema == null ? null : dataSetSchema.getPartitionDimension();
}
/**
 * Looks up the primary key element of the given data set.
 *
 * @param queryCtx query context supplying the semantic schema; must not be {@code null}
 * @param dataSetId id of the data set to look up
 * @return the data set's primary key, or {@code null} when the semantic schema, its data-set
 *     map, or the map entry for {@code dataSetId} is missing
 */
protected SchemaElement getPrimaryKey(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
    SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
    if (semanticSchema == null) {
        return null;
    }
    Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
    if (dataSetSchemaMap == null) {
        return null;
    }
    DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
    // Map.get returns null for an unknown id; answer "no primary key" rather than failing
    // with a NullPointerException.
    return dataSetSchema == null ? null : dataSetSchema.getPrimaryKey();
}
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
@@ -89,6 +90,17 @@ public class PromptHelper {
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
metricStr.append(" ALIAS '" + alias + "'");
}
if (StringUtils.isNotEmpty(metric.getDataFormatType())) {
String dataFormatType = metric.getDataFormatType();
if (DataFormatTypeEnum.DECIMAL
.getName()
.equalsIgnoreCase(dataFormatType)
|| DataFormatTypeEnum.PERCENT
.getName()
.equalsIgnoreCase(dataFormatType)) {
metricStr.append(" FORMAT '" + dataFormatType + "'");
}
}
if (StringUtils.isNotEmpty(metric.getDescription())) {
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
}