diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index 5f2d79b26..fc46ff626 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -134,13 +134,21 @@ public class DataSetSchema { } public boolean containsPartitionDimensions() { - return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime); + return dimensions.stream().anyMatch(SchemaElement::isPartitionTime); } public SchemaElement getPartitionDimension() { for (SchemaElement dimension : dimensions) { - String partitionTimeFormat = dimension.getPartitionTimeFormat(); - if (StringUtils.isNotBlank(partitionTimeFormat)) { + if (dimension.isPartitionTime()) { + return dimension; + } + } + return null; + } + + public SchemaElement getPrimaryKey() { + for (SchemaElement dimension : dimensions) { + if (dimension.isPrimaryKey()) { return dimension; } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index fbe001269..aef5d66c1 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -63,7 +63,7 @@ public class SchemaElement implements Serializable { return Objects.hashCode(dataSetId, id, name, bizName, type); } - public boolean containsPartitionTime() { + public boolean isPartitionTime() { if (MapUtils.isEmpty(extInfo)) { return false; } @@ -78,6 +78,21 @@ public class SchemaElement implements Serializable { return DimensionType.isPartitionTime(dimensionTYpe); } + public boolean isPrimaryKey() { + if (MapUtils.isEmpty(extInfo)) { + return false; + } + Object o = extInfo.get(DimensionConstants.DIMENSION_TYPE); + DimensionType dimensionTYpe = null; + if (o instanceof DimensionType) { + dimensionTYpe = (DimensionType) o; + } + if (o instanceof String) { + dimensionTYpe = DimensionType.valueOf((String) o); + } + return DimensionType.isIdentity(dimensionTYpe); + } + public String getTimeFormat() { if (MapUtils.isEmpty(extInfo)) { return null; @@ -87,7 +102,7 @@ public class SchemaElement implements Serializable { public String getPartitionTimeFormat() { String timeFormat = getTimeFormat(); - if (StringUtils.isNotBlank(timeFormat) && containsPartitionTime()) { + if (StringUtils.isNotBlank(timeFormat) && isPartitionTime()) { return timeFormat; } return ""; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java index d1a7d5955..ba320e389 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java @@ -21,4 +21,8 @@ public enum DimensionType { public static boolean isPartitionTime(DimensionType type) { return type == partition_time; } + + public static boolean isIdentity(DimensionType type) { + return type == identify; + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index a8fdb6feb..8cf0bc0ac 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -72,6 +72,8 @@ public class LLMRequestService { llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId)); llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId)); + llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId)); + llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId)); llmSchema.setTerms(getTerms(queryCtx, dataSetId)); llmReq.setSchema(llmSchema); @@ -133,12 +135,8 @@ public class LLMRequestService { private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) { StringBuilder priorKnowledgeBuilder = new StringBuilder(); SemanticSchema semanticSchema = queryContext.getSemanticSchema(); - appendMetricPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema); - // 处理维度字段 - appendDimensionPriorKnowledge(llmSchema, priorKnowledgeBuilder, semanticSchema); - return priorKnowledgeBuilder.toString(); } @@ -193,27 +191,6 @@ public class LLMRequestService { (k1, k2) -> k1)); } - private void appendDimensionPriorKnowledge( - LLMReq.LLMSchema llmSchema, - StringBuilder priorKnowledgeBuilder, - SemanticSchema semanticSchema) { - Map fieldNameToDateFormat = getFieldNameToDateFormatMap(semanticSchema); - - for (SchemaElement schemaElement : llmSchema.getDimensions()) { - String fieldName = schemaElement.getName(); - String timeFormat = fieldNameToDateFormat.get(fieldName); - if (StringUtils.isBlank(timeFormat)) { - continue; - } - if (schemaElement.containsPartitionTime()) { - priorKnowledgeBuilder.append( - String.format("%s 是分区时间且格式是%s", fieldName, timeFormat)); - } else { - priorKnowledgeBuilder.append(String.format("%s 的时间格式是%s", fieldName, timeFormat)); - } - } - } - public List getValues(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); @@ -263,32 +240,39 @@ public class LLMRequestService { return schemaElements; } + protected SchemaElement getPartitionTime(ChatQueryContext queryCtx, Long dataSetId) { + SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); + if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { + return null; + } + Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); + DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); + return dataSetSchema.getPartitionDimension(); + } + + protected SchemaElement getPrimaryKey(ChatQueryContext queryCtx, Long dataSetId) { + SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); + if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { + return null; + } + Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); + DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); + return dataSetSchema.getPrimaryKey(); + } + protected List getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); - Set dimensionElements = + List dimensionElements = matchedElements.stream() .filter( element -> SchemaElementType.DIMENSION.equals( element.getElement().getType())) .map(SchemaElementMatch::getElement) - .collect(Collectors.toSet()); - SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { - return new ArrayList<>(dimensionElements); - } + .collect(Collectors.toList()); - Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); - DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); - if (dataSetSchema == null) { - return new ArrayList<>(dimensionElements); - } - SchemaElement partitionDimension = dataSetSchema.getPartitionDimension(); - if (partitionDimension != null) { - dimensionElements.add(partitionDimension); - } return new ArrayList<>(dimensionElements); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index d3a2179b1..cdc4bf395 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -31,11 +31,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "please convert it to a SQL query so that relevant data could be returned " + "by executing the SQL query against underlying database." + "\n#Rules:" - + "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate." + + "1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate." + "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." - + "3.ALWAYS calculate the absolute date range by yourself." - + "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." - + "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed." + + "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + + "4.DO NOT calculate date range using functions." + + "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "6.ONLY respond with the converted SQL statement." + "\n#Exemplars:\n{{exemplar}}" + "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:"; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index cefe470fe..3ad844a36 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -133,10 +133,28 @@ public class PromptHelper { values.add(valueStr.toString()); }); - String template = "Table=[%s], Metrics=[%s], Dimensions=[%s], Values=[%s]"; + String partitionTimeStr = ""; + if (llmReq.getSchema().getPartitionTime() != null) { + partitionTimeStr = + String.format( + "%s FORMAT '%s'", + llmReq.getSchema().getPartitionTime().getName(), + llmReq.getSchema().getPartitionTime().getTimeFormat()); + } + + String primaryKeyStr = ""; + if (llmReq.getSchema().getPrimaryKey() != null) { + primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName()); + } + + String template = + "Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], " + + "Metrics=[%s], Dimensions=[%s], Values=[%s]"; return String.format( template, tableStr, + partitionTimeStr, + primaryKeyStr, String.join(",", metrics), String.join(",", dimensions), String.join(",", values)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index ac86244bd..dcf08c6c9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -7,11 +7,8 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; -import org.apache.commons.collections4.CollectionUtils; -import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; @Data public class LLMReq { @@ -38,24 +35,9 @@ public class LLMReq { private List fieldNameList; private List metrics; private List dimensions; + private SchemaElement partitionTime; + private SchemaElement primaryKey; private List terms; - - public List getFieldNameList() { - List fieldNameList = new ArrayList<>(); - if (CollectionUtils.isNotEmpty(metrics)) { - fieldNameList.addAll( - metrics.stream() - .map(metric -> metric.getName()) - .collect(Collectors.toList())); - } - if (CollectionUtils.isNotEmpty(dimensions)) { - fieldNameList.addAll( - dimensions.stream() - .map(metric -> metric.getName()) - .collect(Collectors.toList())); - } - return fieldNameList; - } } @Data diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index d7bd189d1..0b559876e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -170,7 +170,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } for (Entry> entry : id2Values.entrySet()) { SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey()); - if (dimension.containsPartitionTime()) { + if (dimension.isPartitionTime()) { continue; } if (entry.getValue().size() == 1) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java index f5d4c15c0..c7ae7ba47 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java @@ -61,9 +61,7 @@ public class QueryReqBuilder { addDateDimension(parseInfo); if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) { - parseInfo - .getDimensions() - .removeIf(schemaElement -> schemaElement.containsPartitionTime()); + parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime()); } queryStructReq.setGroups( parseInfo.getDimensions().stream()