mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement][headless]Add partitionTime and primaryKey field into the Schema part of the Text2SQL prompt. #1621
This commit is contained in:
@@ -72,6 +72,8 @@ public class LLMRequestService {
|
||||
|
||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
|
||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
@@ -133,12 +135,8 @@ public class LLMRequestService {
|
||||
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
|
||||
StringBuilder priorKnowledgeBuilder = new StringBuilder();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
|
||||
|
||||
// 处理维度字段
|
||||
appendDimensionPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
|
||||
|
||||
return priorKnowledgeBuilder.toString();
|
||||
}
|
||||
|
||||
@@ -193,27 +191,6 @@ public class LLMRequestService {
|
||||
(k1, k2) -> k1));
|
||||
}
|
||||
|
||||
private void appendDimensionPriorKnowledge(
|
||||
LLMReq.LLMSchema llmSchema,
|
||||
StringBuilder priorKnowledgeBuilder,
|
||||
SemanticSchema semanticSchema) {
|
||||
Map<String, String> fieldNameToDateFormat = getFieldNameToDateFormatMap(semanticSchema);
|
||||
|
||||
for (SchemaElement schemaElement : llmSchema.getDimensions()) {
|
||||
String fieldName = schemaElement.getName();
|
||||
String timeFormat = fieldNameToDateFormat.get(fieldName);
|
||||
if (StringUtils.isBlank(timeFormat)) {
|
||||
continue;
|
||||
}
|
||||
if (schemaElement.containsPartitionTime()) {
|
||||
priorKnowledgeBuilder.append(
|
||||
String.format("%s 是分区时间且格式是%s", fieldName, timeFormat));
|
||||
} else {
|
||||
priorKnowledgeBuilder.append(String.format("%s 的时间格式是%s", fieldName, timeFormat));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
@@ -263,32 +240,39 @@ public class LLMRequestService {
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPartitionDimension();
|
||||
}
|
||||
|
||||
protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPrimaryKey();
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
Set<SchemaElement> dimensionElements =
|
||||
List<SchemaElement> dimensionElements =
|
||||
matchedElements.stream()
|
||||
.filter(
|
||||
element ->
|
||||
SchemaElementType.DIMENSION.equals(
|
||||
element.getElement().getType()))
|
||||
.map(SchemaElementMatch::getElement)
|
||||
.collect(Collectors.toSet());
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return new ArrayList<>(dimensionElements);
|
||||
}
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
if (dataSetSchema == null) {
|
||||
return new ArrayList<>(dimensionElements);
|
||||
}
|
||||
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
|
||||
if (partitionDimension != null) {
|
||||
dimensionElements.add(partitionDimension);
|
||||
}
|
||||
return new ArrayList<>(dimensionElements);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,11 +31,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database."
|
||||
+ "\n#Rules:"
|
||||
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "3.ALWAYS calculate the absolute date range by yourself."
|
||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
||||
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "4.DO NOT calculate date range using functions."
|
||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "6.ONLY respond with the converted SQL statement."
|
||||
+ "\n#Exemplars:\n{{exemplar}}"
|
||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
||||
|
||||
@@ -133,10 +133,28 @@ public class PromptHelper {
|
||||
values.add(valueStr.toString());
|
||||
});
|
||||
|
||||
String template = "Table=[%s], Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||
String partitionTimeStr = "";
|
||||
if (llmReq.getSchema().getPartitionTime() != null) {
|
||||
partitionTimeStr =
|
||||
String.format(
|
||||
"%s FORMAT '%s'",
|
||||
llmReq.getSchema().getPartitionTime().getName(),
|
||||
llmReq.getSchema().getPartitionTime().getTimeFormat());
|
||||
}
|
||||
|
||||
String primaryKeyStr = "";
|
||||
if (llmReq.getSchema().getPrimaryKey() != null) {
|
||||
primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName());
|
||||
}
|
||||
|
||||
String template =
|
||||
"Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||
return String.format(
|
||||
template,
|
||||
tableStr,
|
||||
partitionTimeStr,
|
||||
primaryKeyStr,
|
||||
String.join(",", metrics),
|
||||
String.join(",", dimensions),
|
||||
String.join(",", values));
|
||||
|
||||
@@ -7,11 +7,8 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class LLMReq {
|
||||
@@ -38,24 +35,9 @@ public class LLMReq {
|
||||
private List<String> fieldNameList;
|
||||
private List<SchemaElement> metrics;
|
||||
private List<SchemaElement> dimensions;
|
||||
private SchemaElement partitionTime;
|
||||
private SchemaElement primaryKey;
|
||||
private List<Term> terms;
|
||||
|
||||
public List<String> getFieldNameList() {
|
||||
List<String> fieldNameList = new ArrayList<>();
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldNameList.addAll(
|
||||
metrics.stream()
|
||||
.map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||
fieldNameList.addAll(
|
||||
dimensions.stream()
|
||||
.map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
return fieldNameList;
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
|
||||
@@ -170,7 +170,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
|
||||
SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey());
|
||||
if (dimension.containsPartitionTime()) {
|
||||
if (dimension.isPartitionTime()) {
|
||||
continue;
|
||||
}
|
||||
if (entry.getValue().size() == 1) {
|
||||
|
||||
@@ -61,9 +61,7 @@ public class QueryReqBuilder {
|
||||
addDateDimension(parseInfo);
|
||||
|
||||
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
|
||||
parseInfo
|
||||
.getDimensions()
|
||||
.removeIf(schemaElement -> schemaElement.containsPartitionTime());
|
||||
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
|
||||
}
|
||||
queryStructReq.setGroups(
|
||||
parseInfo.getDimensions().stream()
|
||||
|
||||
Reference in New Issue
Block a user