mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
[improvement][headless] Add partitionTime and primaryKey fields to the Schema part of the Text2SQL prompt. #1621
This commit is contained in:
@@ -134,13 +134,21 @@ public class DataSetSchema {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public boolean containsPartitionDimensions() {
|
public boolean containsPartitionDimensions() {
|
||||||
return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime);
|
return dimensions.stream().anyMatch(SchemaElement::isPartitionTime);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getPartitionDimension() {
|
public SchemaElement getPartitionDimension() {
|
||||||
for (SchemaElement dimension : dimensions) {
|
for (SchemaElement dimension : dimensions) {
|
||||||
String partitionTimeFormat = dimension.getPartitionTimeFormat();
|
if (dimension.isPartitionTime()) {
|
||||||
if (StringUtils.isNotBlank(partitionTimeFormat)) {
|
return dimension;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getPrimaryKey() {
|
||||||
|
for (SchemaElement dimension : dimensions) {
|
||||||
|
if (dimension.isPrimaryKey()) {
|
||||||
return dimension;
|
return dimension;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ public class SchemaElement implements Serializable {
|
|||||||
return Objects.hashCode(dataSetId, id, name, bizName, type);
|
return Objects.hashCode(dataSetId, id, name, bizName, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean containsPartitionTime() {
|
public boolean isPartitionTime() {
|
||||||
if (MapUtils.isEmpty(extInfo)) {
|
if (MapUtils.isEmpty(extInfo)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -78,6 +78,21 @@ public class SchemaElement implements Serializable {
|
|||||||
return DimensionType.isPartitionTime(dimensionTYpe);
|
return DimensionType.isPartitionTime(dimensionTYpe);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isPrimaryKey() {
|
||||||
|
if (MapUtils.isEmpty(extInfo)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Object o = extInfo.get(DimensionConstants.DIMENSION_TYPE);
|
||||||
|
DimensionType dimensionTYpe = null;
|
||||||
|
if (o instanceof DimensionType) {
|
||||||
|
dimensionTYpe = (DimensionType) o;
|
||||||
|
}
|
||||||
|
if (o instanceof String) {
|
||||||
|
dimensionTYpe = DimensionType.valueOf((String) o);
|
||||||
|
}
|
||||||
|
return DimensionType.isIdentity(dimensionTYpe);
|
||||||
|
}
|
||||||
|
|
||||||
public String getTimeFormat() {
|
public String getTimeFormat() {
|
||||||
if (MapUtils.isEmpty(extInfo)) {
|
if (MapUtils.isEmpty(extInfo)) {
|
||||||
return null;
|
return null;
|
||||||
@@ -87,7 +102,7 @@ public class SchemaElement implements Serializable {
|
|||||||
|
|
||||||
public String getPartitionTimeFormat() {
|
public String getPartitionTimeFormat() {
|
||||||
String timeFormat = getTimeFormat();
|
String timeFormat = getTimeFormat();
|
||||||
if (StringUtils.isNotBlank(timeFormat) && containsPartitionTime()) {
|
if (StringUtils.isNotBlank(timeFormat) && isPartitionTime()) {
|
||||||
return timeFormat;
|
return timeFormat;
|
||||||
}
|
}
|
||||||
return "";
|
return "";
|
||||||
|
|||||||
@@ -21,4 +21,8 @@ public enum DimensionType {
|
|||||||
public static boolean isPartitionTime(DimensionType type) {
|
public static boolean isPartitionTime(DimensionType type) {
|
||||||
return type == partition_time;
|
return type == partition_time;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean isIdentity(DimensionType type) {
|
||||||
|
return type == identify;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,6 +72,8 @@ public class LLMRequestService {
|
|||||||
|
|
||||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||||
|
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||||
|
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
|
||||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||||
llmReq.setSchema(llmSchema);
|
llmReq.setSchema(llmSchema);
|
||||||
|
|
||||||
@@ -133,12 +135,8 @@ public class LLMRequestService {
|
|||||||
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
|
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
|
||||||
StringBuilder priorKnowledgeBuilder = new StringBuilder();
|
StringBuilder priorKnowledgeBuilder = new StringBuilder();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
|
appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
|
||||||
|
|
||||||
// 处理维度字段
|
|
||||||
appendDimensionPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema);
|
|
||||||
|
|
||||||
return priorKnowledgeBuilder.toString();
|
return priorKnowledgeBuilder.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,27 +191,6 @@ public class LLMRequestService {
|
|||||||
(k1, k2) -> k1));
|
(k1, k2) -> k1));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void appendDimensionPriorKnowledge(
|
|
||||||
LLMReq.LLMSchema llmSchema,
|
|
||||||
StringBuilder priorKnowledgeBuilder,
|
|
||||||
SemanticSchema semanticSchema) {
|
|
||||||
Map<String, String> fieldNameToDateFormat = getFieldNameToDateFormatMap(semanticSchema);
|
|
||||||
|
|
||||||
for (SchemaElement schemaElement : llmSchema.getDimensions()) {
|
|
||||||
String fieldName = schemaElement.getName();
|
|
||||||
String timeFormat = fieldNameToDateFormat.get(fieldName);
|
|
||||||
if (StringUtils.isBlank(timeFormat)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (schemaElement.containsPartitionTime()) {
|
|
||||||
priorKnowledgeBuilder.append(
|
|
||||||
String.format("%s 是分区时间且格式是%s", fieldName, timeFormat));
|
|
||||||
} else {
|
|
||||||
priorKnowledgeBuilder.append(String.format("%s 的时间格式是%s", fieldName, timeFormat));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
|
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<SchemaElementMatch> matchedElements =
|
List<SchemaElementMatch> matchedElements =
|
||||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
@@ -263,32 +240,39 @@ public class LLMRequestService {
|
|||||||
return schemaElements;
|
return schemaElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
|
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||||
|
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||||
|
return dataSetSchema.getPartitionDimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
|
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||||
|
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||||
|
return dataSetSchema.getPrimaryKey();
|
||||||
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
|
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
|
|
||||||
List<SchemaElementMatch> matchedElements =
|
List<SchemaElementMatch> matchedElements =
|
||||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
Set<SchemaElement> dimensionElements =
|
List<SchemaElement> dimensionElements =
|
||||||
matchedElements.stream()
|
matchedElements.stream()
|
||||||
.filter(
|
.filter(
|
||||||
element ->
|
element ->
|
||||||
SchemaElementType.DIMENSION.equals(
|
SchemaElementType.DIMENSION.equals(
|
||||||
element.getElement().getType()))
|
element.getElement().getType()))
|
||||||
.map(SchemaElementMatch::getElement)
|
.map(SchemaElementMatch::getElement)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toList());
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
|
||||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
|
||||||
return new ArrayList<>(dimensionElements);
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
|
||||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
|
||||||
if (dataSetSchema == null) {
|
|
||||||
return new ArrayList<>(dimensionElements);
|
|
||||||
}
|
|
||||||
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
|
|
||||||
if (partitionDimension != null) {
|
|
||||||
dimensionElements.add(partitionDimension);
|
|
||||||
}
|
|
||||||
return new ArrayList<>(dimensionElements);
|
return new ArrayList<>(dimensionElements);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||||
+ "by executing the SQL query against underlying database."
|
+ "by executing the SQL query against underlying database."
|
||||||
+ "\n#Rules:"
|
+ "\n#Rules:"
|
||||||
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
|
+ "1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
|
||||||
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||||
+ "3.ALWAYS calculate the absolute date range by yourself."
|
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
+ "4.DO NOT calculate date range using functions."
|
||||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||||
+ "6.ONLY respond with the converted SQL statement."
|
+ "6.ONLY respond with the converted SQL statement."
|
||||||
+ "\n#Exemplars:\n{{exemplar}}"
|
+ "\n#Exemplars:\n{{exemplar}}"
|
||||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
||||||
|
|||||||
@@ -133,10 +133,28 @@ public class PromptHelper {
|
|||||||
values.add(valueStr.toString());
|
values.add(valueStr.toString());
|
||||||
});
|
});
|
||||||
|
|
||||||
String template = "Table=[%s], Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
String partitionTimeStr = "";
|
||||||
|
if (llmReq.getSchema().getPartitionTime() != null) {
|
||||||
|
partitionTimeStr =
|
||||||
|
String.format(
|
||||||
|
"%s FORMAT '%s'",
|
||||||
|
llmReq.getSchema().getPartitionTime().getName(),
|
||||||
|
llmReq.getSchema().getPartitionTime().getTimeFormat());
|
||||||
|
}
|
||||||
|
|
||||||
|
String primaryKeyStr = "";
|
||||||
|
if (llmReq.getSchema().getPrimaryKey() != null) {
|
||||||
|
primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
String template =
|
||||||
|
"Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||||
|
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||||
return String.format(
|
return String.format(
|
||||||
template,
|
template,
|
||||||
tableStr,
|
tableStr,
|
||||||
|
partitionTimeStr,
|
||||||
|
primaryKeyStr,
|
||||||
String.join(",", metrics),
|
String.join(",", metrics),
|
||||||
String.join(",", dimensions),
|
String.join(",", dimensions),
|
||||||
String.join(",", values));
|
String.join(",", values));
|
||||||
|
|||||||
@@ -7,11 +7,8 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class LLMReq {
|
public class LLMReq {
|
||||||
@@ -38,24 +35,9 @@ public class LLMReq {
|
|||||||
private List<String> fieldNameList;
|
private List<String> fieldNameList;
|
||||||
private List<SchemaElement> metrics;
|
private List<SchemaElement> metrics;
|
||||||
private List<SchemaElement> dimensions;
|
private List<SchemaElement> dimensions;
|
||||||
|
private SchemaElement partitionTime;
|
||||||
|
private SchemaElement primaryKey;
|
||||||
private List<Term> terms;
|
private List<Term> terms;
|
||||||
|
|
||||||
public List<String> getFieldNameList() {
|
|
||||||
List<String> fieldNameList = new ArrayList<>();
|
|
||||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
|
||||||
fieldNameList.addAll(
|
|
||||||
metrics.stream()
|
|
||||||
.map(metric -> metric.getName())
|
|
||||||
.collect(Collectors.toList()));
|
|
||||||
}
|
|
||||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
|
||||||
fieldNameList.addAll(
|
|
||||||
dimensions.stream()
|
|
||||||
.map(metric -> metric.getName())
|
|
||||||
.collect(Collectors.toList()));
|
|
||||||
}
|
|
||||||
return fieldNameList;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
}
|
}
|
||||||
for (Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
|
for (Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
|
||||||
SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey());
|
SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey());
|
||||||
if (dimension.containsPartitionTime()) {
|
if (dimension.isPartitionTime()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (entry.getValue().size() == 1) {
|
if (entry.getValue().size() == 1) {
|
||||||
|
|||||||
@@ -61,9 +61,7 @@ public class QueryReqBuilder {
|
|||||||
addDateDimension(parseInfo);
|
addDateDimension(parseInfo);
|
||||||
|
|
||||||
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
|
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
|
||||||
parseInfo
|
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
|
||||||
.getDimensions()
|
|
||||||
.removeIf(schemaElement -> schemaElement.containsPartitionTime());
|
|
||||||
}
|
}
|
||||||
queryStructReq.setGroups(
|
queryStructReq.setGroups(
|
||||||
parseInfo.getDimensions().stream()
|
parseInfo.getDimensions().stream()
|
||||||
|
|||||||
Reference in New Issue
Block a user