mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[improvement][headless]Add databaseType into the Schema part of the Text2SQL prompt. #1621
This commit is contained in:
@@ -16,6 +16,7 @@ import java.util.stream.Collectors;
|
||||
@Data
|
||||
public class DataSetSchema {
|
||||
|
||||
private String databaseType;
|
||||
private SchemaElement dataSet;
|
||||
private Set<SchemaElement> metrics = new HashSet<>();
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
|
||||
@@ -13,6 +13,7 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class DataSetSchemaResp extends DataSetResp {
|
||||
|
||||
private String databaseType;
|
||||
private List<MetricSchemaResp> metrics = Lists.newArrayList();
|
||||
private List<DimSchemaResp> dimensions = Lists.newArrayList();
|
||||
private List<ModelResp> modelResps = Lists.newArrayList();
|
||||
|
||||
@@ -61,6 +61,7 @@ public class LLMRequestService {
|
||||
llmReq.setQueryText(queryText);
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmReq.setSchema(llmSchema);
|
||||
llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId));
|
||||
llmSchema.setDataSetId(dataSetId);
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId));
|
||||
@@ -205,4 +206,14 @@ public class LLMRequestService {
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getPrimaryKey();
|
||||
}
|
||||
|
||||
protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getDatabaseType();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,11 +162,17 @@ public class PromptHelper {
|
||||
primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName());
|
||||
}
|
||||
|
||||
String databaseTypeStr = "";
|
||||
if (llmReq.getSchema().getDatabaseType() != null) {
|
||||
databaseTypeStr = llmReq.getSchema().getDatabaseType();
|
||||
}
|
||||
|
||||
String template =
|
||||
"Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||
return String.format(
|
||||
template,
|
||||
databaseTypeStr,
|
||||
tableStr,
|
||||
partitionTimeStr,
|
||||
primaryKeyStr,
|
||||
|
||||
@@ -33,6 +33,7 @@ public class LLMReq {
|
||||
|
||||
@Data
|
||||
public static class LLMSchema {
|
||||
private String databaseType;
|
||||
private Long dataSetId;
|
||||
private String dataSetName;
|
||||
private List<SchemaElement> metrics;
|
||||
|
||||
@@ -246,6 +246,12 @@ public class SchemaServiceImpl implements SchemaService {
|
||||
.collect(Collectors.toList()));
|
||||
dataSetSchemaResp.setTermResps(
|
||||
termMaps.getOrDefault(dataSetResp.getDomainId(), Lists.newArrayList()));
|
||||
if (!CollectionUtils.isEmpty(dataSetSchemaResp.getModelResps())) {
|
||||
DatabaseResp databaseResp =
|
||||
databaseService.getDatabase(
|
||||
dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
|
||||
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
|
||||
}
|
||||
dataSetSchemaResps.add(dataSetSchemaResp);
|
||||
}
|
||||
fillStaticInfo(dataSetSchemaResps);
|
||||
|
||||
@@ -41,6 +41,7 @@ public class DataSetSchemaBuilder {
|
||||
.type(SchemaElementType.DATASET)
|
||||
.build();
|
||||
dataSetSchema.setDataSet(dataSet);
|
||||
dataSetSchema.setDatabaseType(resp.getDatabaseType());
|
||||
|
||||
Set<SchemaElement> metrics = getMetrics(resp);
|
||||
dataSetSchema.getMetrics().addAll(metrics);
|
||||
|
||||
Reference in New Issue
Block a user