From 693356e46a96eddeb428a7ad18b2e317a6f6882c Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 12 Sep 2024 18:13:44 +0800 Subject: [PATCH] [improvement][headless]Add `metricFormat` field to metric metadata of the Text2SQL prompt. #1621 --- .../chat/parser/llm/LLMRequestService.java | 91 ++++++------------- .../chat/parser/llm/PromptHelper.java | 12 +++ .../src/test/resources/s2-exemplar.json | 16 ++-- 3 files changed, 46 insertions(+), 73 deletions(-) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 8cf0bc0ac..9be8431bf 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.parser.llm; -import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -14,8 +13,8 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; +import org.jetbrains.annotations.NotNull; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -26,7 +25,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -69,7 +67,6 @@ public class LLMRequestService { LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); llmSchema.setDataSetId(dataSetId); llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); - llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId)); llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId)); llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId)); @@ -77,9 +74,6 @@ public class LLMRequestService { llmSchema.setTerms(getTerms(queryCtx, dataSetId)); llmReq.setSchema(llmSchema); - String priorKnowledge = getPriorKnowledge(queryCtx, llmSchema); - llmReq.setPriorExts(priorKnowledge); - List linking = new ArrayList<>(); boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE)); @@ -132,14 +126,6 @@ public class LLMRequestService { .collect(Collectors.toList()); } - private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) { - StringBuilder priorKnowledgeBuilder = new StringBuilder(); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); - appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema); - - return priorKnowledgeBuilder.toString(); - } - private Map getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) { return semanticSchema.getMetrics().stream() .filter(metric -> Objects.nonNull(metric.getDataFormatType())) @@ -164,34 +150,7 @@ public class LLMRequestService { (existing, replacement) -> existing)); } - private void appendMetricPriorKnowledge( - LLMReq.LLMSchema llmSchema, - StringBuilder priorKnowledgeBuilder, - SemanticSchema semanticSchema) { - Map fieldNameToDataFormatType = - getFieldNameToDataFormatTypeMap(semanticSchema); - - for (SchemaElement schemaElement : llmSchema.getMetrics()) { - String fieldName = schemaElement.getName(); - String dataFormatType = fieldNameToDataFormatType.get(fieldName); - if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType) - || DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) { - priorKnowledgeBuilder.append(String.format("%s的计量单位是%s; ", fieldName, "小数")); - } - } - } - - private Map getFieldNameToDateFormatMap(SemanticSchema semanticSchema) { - return semanticSchema.getDimensions().stream() - .filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat())) - .collect( - Collectors.toMap( - SchemaElement::getName, - value -> Optional.ofNullable(value.getTimeFormat()).orElse(""), - (k1, k2) -> k1)); - } - - public List getValues(ChatQueryContext queryCtx, Long dataSetId) { + public List getValues(@NotNull ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -218,7 +177,8 @@ public class LLMRequestService { return new ArrayList<>(valueMatches); } - protected List getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) { + protected List getMatchedMetrics( + @NotNull ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -240,27 +200,8 @@ public class LLMRequestService { return schemaElements; } - protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) { - SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { - return null; - } - Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); - DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); - return dataSetSchema.getPartitionDimension(); - } - - protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) { - SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { - return null; - } - Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); - DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); - return dataSetSchema.getPrimaryKey(); - } - - protected List getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) { + protected List getMatchedDimensions( + @NotNull ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); @@ -275,4 +216,24 @@ public class LLMRequestService { return new ArrayList<>(dimensionElements); } + + protected SchemaElement getPartitionTime(@NotNull ChatQueryContext queryCtx, Long dataSetId) { + SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); + if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { + return null; + } + Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); + DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); + return dataSetSchema.getPartitionDimension(); + } + + protected SchemaElement getPrimaryKey(@NotNull ChatQueryContext queryCtx, Long dataSetId) { + SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); + if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { + return null; + } + Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); + DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); + return dataSetSchema.getPrimaryKey(); + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index ce88d2f11..715f145da 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; +import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; @@ -89,6 +90,17 @@ public class PromptHelper { metric.getAlias().stream().forEach(a -> alias.append(a + ",")); metricStr.append(" ALIAS '" + alias + "'"); } + if (StringUtils.isNotEmpty(metric.getDataFormatType())) { + String dataFormatType = metric.getDataFormatType(); + if (DataFormatTypeEnum.DECIMAL + .getName() + .equalsIgnoreCase(dataFormatType) + || DataFormatTypeEnum.PERCENT + .getName() + .equalsIgnoreCase(dataFormatType)) { + metricStr.append(" FORMAT '" + dataFormatType + "'"); + } + } if (StringUtils.isNotEmpty(metric.getDescription())) { metricStr.append(" COMMENT '" + metric.getDescription() + "'"); } diff --git a/launchers/standalone/src/test/resources/s2-exemplar.json b/launchers/standalone/src/test/resources/s2-exemplar.json index d9afdf04a..d73c73d8b 100644 --- a/launchers/standalone/src/test/resources/s2-exemplar.json +++ b/launchers/standalone/src/test/resources/s2-exemplar.json @@ -2,49 +2,49 @@ { "question": "比较jackjchen和robinlee今年以来的访问次数", "sideInfo": "CurrentDate=[2020-12-01],DomainTerms=[<核心用户 COMMENT '核心用户指tom和lucy'>]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[], Values[<用户='jackjchen'>,<用户='robinlee'>]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<数据日期>], Values[<用户='jackjchen'>,<用户='robinlee'>]", "sql": "SELECT 用户, 访问次数 FROM 超音数产品 WHERE 用户 IN ('jackjchen', 'robinlee') AND 数据日期 >= '2020-01-01' AND 数据日期 <= '2020-12-01'" }, { "question": "超音数近12个月访问人数 按部门", "sideInfo": "CurrentDate=[2022-11-06]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>], Values=[]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>,<数据日期>], Values=[]", "sql": "SELECT 部门, 数据日期, 访问人数 FROM 超音数产品 WHERE 数据日期 >= '2021-11-06' AND 数据日期 <= '2022-11-06'" }, { "question": "超音数过去90天美术部、技术研发部的访问时长", "sideInfo": "CurrentDate=[2023-04-21]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[], Values=[<部门='美术部'>,<部门='技术研发部'>]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]", "sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'" }, { "question": "超音数访问时长小于1小时,且来自美术部的用户是哪些", "sideInfo": "CurrentDate=[2023-07-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]", - "dbSchema": "Table:[超音数产品], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>], Values:[<部门='美术部'>]", + "dbSchema": "Table:[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>,<数据日期>], Values:[<部门='美术部'>]", "sql": "SELECT 用户 FROM 超音数产品 WHERE 部门 = '美术部' AND 访问时长 < 1" }, { "question": "超音数本月pv最高的用户有哪些", "sideInfo": "CurrentDate=[2023-08-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>], Values=[]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]", "sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-08-01' AND 数据日期 <= '2023-08-31' ORDER BY 访问次数 DESC LIMIT 1" }, { "question": "超音数访问次数大于1k的部门是哪些", "sideInfo": "CurrentDate=[2023-09-14]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>], Values=[]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]", "sql": "SELECT 部门 FROM 超音数产品 WHERE 访问次数 > 1000" }, { "question": "过去半个月核心用户的访问次数", "sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<核心用户 COMMENT '用户为alice'>]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>], Values=[]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]", "sql": "SELECT 用户,SUM(访问次数) FROM 超音数产品 WHERE 用户='alice' AND 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15'" }, { "question": "过去半个月忠实用户有哪一些", "sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<忠实用户 COMMENT '一段时间内总访问次数大于100的用户'>]", - "dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>], Values=[]", + "dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]", "sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15' GROUP BY 用户 HAVING SUM(访问次数) > 100" } ] \ No newline at end of file