mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
[improvement][headless]Add metricFormat field to metric metadata of the Text2SQL prompt. #1621
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -14,8 +13,8 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -26,7 +25,6 @@ import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -69,7 +67,6 @@ public class LLMRequestService {
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setDataSetId(dataSetId);
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
|
||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||
@@ -77,9 +74,6 @@ public class LLMRequestService {
|
||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
String priorKnowledge = getPriorKnowledge(queryCtx, llmSchema);
|
||||
llmReq.setPriorExts(priorKnowledge);
|
||||
|
||||
List<LLMReq.ElementValue> linking = new ArrayList<>();
|
||||
boolean linkingValueEnabled =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
@@ -132,14 +126,6 @@ public class LLMRequestService {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
|
||||
StringBuilder priorKnowledgeBuilder = new StringBuilder();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
|
||||
|
||||
return priorKnowledgeBuilder.toString();
|
||||
}
|
||||
|
||||
private Map<String, String> getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) {
|
||||
return semanticSchema.getMetrics().stream()
|
||||
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
|
||||
@@ -164,34 +150,7 @@ public class LLMRequestService {
|
||||
(existing, replacement) -> existing));
|
||||
}
|
||||
|
||||
private void appendMetricPriorKnowledge(
|
||||
LLMReq.LLMSchema llmSchema,
|
||||
StringBuilder priorKnowledgeBuilder,
|
||||
SemanticSchema semanticSchema) {
|
||||
Map<String, String> fieldNameToDataFormatType =
|
||||
getFieldNameToDataFormatTypeMap(semanticSchema);
|
||||
|
||||
for (SchemaElement schemaElement : llmSchema.getMetrics()) {
|
||||
String fieldName = schemaElement.getName();
|
||||
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
|
||||
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
||||
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
||||
priorKnowledgeBuilder.append(String.format("%s的计量单位是%s; ", fieldName, "小数"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, String> getFieldNameToDateFormatMap(SemanticSchema semanticSchema) {
|
||||
return semanticSchema.getDimensions().stream()
|
||||
.filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat()))
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getName,
|
||||
value -> Optional.ofNullable(value.getTimeFormat()).orElse(""),
|
||||
(k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
public List<LLMReq.ElementValue> getValues(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -218,7 +177,8 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected List<SchemaElement> getMatchedMetrics(
|
||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -240,27 +200,8 @@ public class LLMRequestService {
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPartitionDimension();
|
||||
}
|
||||
|
||||
protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPrimaryKey();
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected List<SchemaElement> getMatchedDimensions(
|
||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
@@ -275,4 +216,24 @@ public class LLMRequestService {
|
||||
|
||||
return new ArrayList<>(dimensionElements);
|
||||
}
|
||||
|
||||
protected SchemaElement getPartitionTime(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPartitionDimension();
|
||||
}
|
||||
|
||||
protected SchemaElement getPrimaryKey(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPrimaryKey();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -89,6 +90,17 @@ public class PromptHelper {
|
||||
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
|
||||
metricStr.append(" ALIAS '" + alias + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDataFormatType())) {
|
||||
String dataFormatType = metric.getDataFormatType();
|
||||
if (DataFormatTypeEnum.DECIMAL
|
||||
.getName()
|
||||
.equalsIgnoreCase(dataFormatType)
|
||||
|| DataFormatTypeEnum.PERCENT
|
||||
.getName()
|
||||
.equalsIgnoreCase(dataFormatType)) {
|
||||
metricStr.append(" FORMAT '" + dataFormatType + "'");
|
||||
}
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
|
||||
@@ -2,49 +2,49 @@
|
||||
{
|
||||
"question": "比较jackjchen和robinlee今年以来的访问次数",
|
||||
"sideInfo": "CurrentDate=[2020-12-01],DomainTerms=[<核心用户 COMMENT '核心用户指tom和lucy'>]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[], Values[<用户='jackjchen'>,<用户='robinlee'>]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<数据日期>], Values=[<用户='jackjchen'>,<用户='robinlee'>]",
|
||||
"sql": "SELECT 用户, 访问次数 FROM 超音数产品 WHERE 用户 IN ('jackjchen', 'robinlee') AND 数据日期 >= '2020-01-01' AND 数据日期 <= '2020-12-01'"
|
||||
},
|
||||
{
|
||||
"question": "超音数近12个月访问人数 按部门",
|
||||
"sideInfo": "CurrentDate=[2022-11-06]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>], Values=[]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>,<数据日期>], Values=[]",
|
||||
"sql": "SELECT 部门, 数据日期, 访问人数 FROM 超音数产品 WHERE 数据日期 >= '2021-11-06' AND 数据日期 <= '2022-11-06'"
|
||||
},
|
||||
{
|
||||
"question": "超音数过去90天美术部、技术研发部的访问时长",
|
||||
"sideInfo": "CurrentDate=[2023-04-21]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[], Values=[<部门='美术部'>,<部门='技术研发部'>]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
|
||||
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
|
||||
},
|
||||
{
|
||||
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
|
||||
"sideInfo": "CurrentDate=[2023-07-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]",
|
||||
"dbSchema": "Table:[超音数产品], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>], Values:[<部门='美术部'>]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[<部门='美术部'>]",
|
||||
"sql": "SELECT 用户 FROM 超音数产品 WHERE 部门 = '美术部' AND 访问时长 < 1"
|
||||
},
|
||||
{
|
||||
"question": "超音数本月pv最高的用户有哪些",
|
||||
"sideInfo": "CurrentDate=[2023-08-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>], Values=[]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
|
||||
"sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-08-01' AND 数据日期 <= '2023-08-31' ORDER BY 访问次数 DESC LIMIT 1"
|
||||
},
|
||||
{
|
||||
"question": "超音数访问次数大于1k的部门是哪些",
|
||||
"sideInfo": "CurrentDate=[2023-09-14]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>], Values=[]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
|
||||
"sql": "SELECT 部门 FROM 超音数产品 WHERE 访问次数 > 1000"
|
||||
},
|
||||
{
|
||||
"question": "过去半个月核心用户的访问次数",
|
||||
"sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<核心用户 COMMENT '用户为alice'>]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>], Values=[]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
|
||||
"sql": "SELECT 用户,SUM(访问次数) FROM 超音数产品 WHERE 用户='alice' AND 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15'"
|
||||
},
|
||||
{
|
||||
"question": "过去半个月忠实用户有哪一些",
|
||||
"sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<忠实用户 COMMENT '一段时间内总访问次数大于100的用户'>]",
|
||||
"dbSchema": "Table=[超音数产品], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>], Values=[]",
|
||||
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
|
||||
"sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15' GROUP BY 用户 HAVING SUM(访问次数) > 100"
|
||||
}
|
||||
]
|
||||
Reference in New Issue
Block a user