[improvement][headless] Add partitionTime and primaryKey fields to the Schema part of the Text2SQL prompt. #1621

This commit is contained in:
jerryjzhang
2024-09-11 21:29:37 +08:00
parent a82e3c8b1e
commit 82b5fa966a
9 changed files with 83 additions and 74 deletions

View File

@@ -72,6 +72,8 @@ public class LLMRequestService {
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);
@@ -133,12 +135,8 @@ public class LLMRequestService {
/**
 * Builds the prior-knowledge text for the given LLM schema.
 *
 * <p>A fresh {@link StringBuilder} is handed to the metric and dimension append helpers in
 * turn, and its accumulated content is returned.
 *
 * @param queryContext query context from which the semantic schema is read
 * @param llmSchema schema whose metrics and dimensions are passed to the append helpers
 * @return the text accumulated by the helpers (empty when they append nothing)
 */
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
StringBuilder priorKnowledgeBuilder = new StringBuilder();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
// Handle dimension fields
appendDimensionPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
return priorKnowledgeBuilder.toString();
}
@@ -193,27 +191,6 @@ public class LLMRequestService {
(k1, k2) -> k1));
}
/**
 * Appends one time-format hint per schema dimension that has a non-blank date format.
 *
 * <p>Dimensions whose name has no (or a blank) entry in the field-name-to-date-format map are
 * skipped. Dimensions reporting {@code containsPartitionTime()} get the partition-time wording;
 * all others get the plain time-format wording.
 *
 * @param llmSchema schema whose dimensions are inspected
 * @param priorKnowledgeBuilder builder the hints are appended to
 * @param semanticSchema schema used to build the field-name-to-date-format map
 */
private void appendDimensionPriorKnowledge(
        LLMReq.LLMSchema llmSchema,
        StringBuilder priorKnowledgeBuilder,
        SemanticSchema semanticSchema) {
    Map<String, String> dateFormatsByField = getFieldNameToDateFormatMap(semanticSchema);
    for (SchemaElement dimension : llmSchema.getDimensions()) {
        String name = dimension.getName();
        String format = dateFormatsByField.get(name);
        if (!StringUtils.isBlank(format)) {
            // Pick the wording first, then format once with the same two arguments.
            String pattern =
                    dimension.containsPartitionTime()
                            ? "%s 是分区时间且格式是%s"
                            : "%s 的时间格式是%s";
            priorKnowledgeBuilder.append(String.format(pattern, name, format));
        }
    }
}
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
@@ -263,32 +240,39 @@ public class LLMRequestService {
return schemaElements;
}
/**
 * Looks up the partition dimension of the given data set.
 *
 * @param queryCtx query context supplying the semantic schema
 * @param dataSetId id of the data set to look up
 * @return the data set's partition dimension, or {@code null} when the semantic schema, its
 *     data-set schema map, or the schema for {@code dataSetId} is not available
 */
protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) {
    SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
    if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
        return null;
    }
    Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
    DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
    // The map may have no entry for this id; return null instead of dereferencing it.
    if (dataSetSchema == null) {
        return null;
    }
    return dataSetSchema.getPartitionDimension();
}
/**
 * Looks up the primary-key element of the given data set.
 *
 * @param queryCtx query context supplying the semantic schema
 * @param dataSetId id of the data set to look up
 * @return the data set's primary key, or {@code null} when the semantic schema, its data-set
 *     schema map, or the schema for {@code dataSetId} is not available
 */
protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) {
    SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
    if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
        return null;
    }
    Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
    DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
    // The map may have no entry for this id; return null instead of dereferencing it.
    if (dataSetSchema == null) {
        return null;
    }
    return dataSetSchema.getPrimaryKey();
}
/**
 * Collects the schema elements matched for the given data set whose type is
 * {@code SchemaElementType.DIMENSION}.
 *
 * <p>NOTE(review): this span comes from a diff view and appears to interleave the removed
 * Set-based version (which also added the data set's partition dimension) with the added
 * List-based version. Confirm against the actual file before relying on the statement order
 * shown here.
 *
 * @param queryCtx query context supplying the map info (and, in the Set-based lines, the
 *     semantic schema)
 * @param dataSetId id of the data set whose matched elements are read
 * @return the matched dimension elements
 */
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
Set<SchemaElement> dimensionElements =
List<SchemaElement> dimensionElements =
matchedElements.stream()
.filter(
element ->
SchemaElementType.DIMENSION.equals(
element.getElement().getType()))
.map(SchemaElementMatch::getElement)
.collect(Collectors.toSet());
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
return new ArrayList<>(dimensionElements);
}
.collect(Collectors.toList());
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
// No schema registered for this data set: return only the matched dimensions.
if (dataSetSchema == null) {
return new ArrayList<>(dimensionElements);
}
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (partitionDimension != null) {
dimensionElements.add(partitionDimension);
}
return new ArrayList<>(dimensionElements);
}
}

View File

@@ -31,11 +31,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database."
+ "\n#Rules:"
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
+ "1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "3.ALWAYS calculate the absolute date range by yourself."
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "4.DO NOT calculate date range using functions."
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "6.ONLY respond with the converted SQL statement."
+ "\n#Exemplars:\n{{exemplar}}"
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";

View File

@@ -133,10 +133,28 @@ public class PromptHelper {
values.add(valueStr.toString());
});
String template = "Table=[%s], Metrics=[%s], Dimensions=[%s], Values=[%s]";
String partitionTimeStr = "";
if (llmReq.getSchema().getPartitionTime() != null) {
partitionTimeStr =
String.format(
"%s FORMAT '%s'",
llmReq.getSchema().getPartitionTime().getName(),
llmReq.getSchema().getPartitionTime().getTimeFormat());
}
String primaryKeyStr = "";
if (llmReq.getSchema().getPrimaryKey() != null) {
primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName());
}
String template =
"Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
return String.format(
template,
tableStr,
partitionTimeStr,
primaryKeyStr,
String.join(",", metrics),
String.join(",", dimensions),
String.join(",", values));

View File

@@ -7,11 +7,8 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import org.apache.commons.collections4.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@Data
public class LLMReq {
@@ -38,24 +35,9 @@ public class LLMReq {
private List<String> fieldNameList;
private List<SchemaElement> metrics;
private List<SchemaElement> dimensions;
private SchemaElement partitionTime;
private SchemaElement primaryKey;
private List<Term> terms;
/**
 * Returns the names of all metrics followed by the names of all dimensions.
 *
 * <p>A {@code null} or empty metrics/dimensions list contributes nothing. The result is a new
 * list on every call.
 *
 * @return metric names, then dimension names, in list order
 */
public List<String> getFieldNameList() {
    List<String> names = new ArrayList<>();
    if (CollectionUtils.isNotEmpty(metrics)) {
        for (SchemaElement metric : metrics) {
            names.add(metric.getName());
        }
    }
    if (CollectionUtils.isNotEmpty(dimensions)) {
        for (SchemaElement dimension : dimensions) {
            names.add(dimension.getName());
        }
    }
    return names;
}
}
@Data

View File

@@ -170,7 +170,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}
for (Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey());
if (dimension.containsPartitionTime()) {
if (dimension.isPartitionTime()) {
continue;
}
if (entry.getValue().size() == 1) {

View File

@@ -61,9 +61,7 @@ public class QueryReqBuilder {
addDateDimension(parseInfo);
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
parseInfo
.getDimensions()
.removeIf(schemaElement -> schemaElement.containsPartitionTime());
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
}
queryStructReq.setGroups(
parseInfo.getDimensions().stream()