[improvement][headless]Add databaseType into the Schema part of the Text2SQL prompt. #1621

This commit is contained in:
jerryjzhang
2024-09-12 23:20:49 +08:00
parent 47cc933aec
commit 37f12391b0
9 changed files with 44 additions and 17 deletions

View File

@@ -16,6 +16,7 @@ import java.util.stream.Collectors;
@Data
public class DataSetSchema {
private String databaseType;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();

View File

@@ -13,6 +13,7 @@ import java.util.List;
@NoArgsConstructor
public class DataSetSchemaResp extends DataSetResp {
private String databaseType;
private List<MetricSchemaResp> metrics = Lists.newArrayList();
private List<DimSchemaResp> dimensions = Lists.newArrayList();
private List<ModelResp> modelResps = Lists.newArrayList();

View File

@@ -61,6 +61,7 @@ public class LLMRequestService {
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmReq.setSchema(llmSchema);
llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId));
llmSchema.setDataSetId(dataSetId);
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId));
@@ -205,4 +206,14 @@ public class LLMRequestService {
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
return dataSetSchema.getPrimaryKey();
}
protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
return null;
}
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
return dataSetSchema.getDatabaseType();
}
}

View File

@@ -162,11 +162,17 @@ public class PromptHelper {
primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName());
}
String databaseTypeStr = "";
if (llmReq.getSchema().getDatabaseType() != null) {
databaseTypeStr = llmReq.getSchema().getDatabaseType();
}
String template =
"Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
return String.format(
template,
databaseTypeStr,
tableStr,
partitionTimeStr,
primaryKeyStr,

View File

@@ -33,6 +33,7 @@ public class LLMReq {
@Data
public static class LLMSchema {
private String databaseType;
private Long dataSetId;
private String dataSetName;
private List<SchemaElement> metrics;

View File

@@ -246,6 +246,12 @@ public class SchemaServiceImpl implements SchemaService {
.collect(Collectors.toList()));
dataSetSchemaResp.setTermResps(
termMaps.getOrDefault(dataSetResp.getDomainId(), Lists.newArrayList()));
if (!CollectionUtils.isEmpty(dataSetSchemaResp.getModelResps())) {
DatabaseResp databaseResp =
databaseService.getDatabase(
dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
}
dataSetSchemaResps.add(dataSetSchemaResp);
}
fillStaticInfo(dataSetSchemaResps);

View File

@@ -41,6 +41,7 @@ public class DataSetSchemaBuilder {
.type(SchemaElementType.DATASET)
.build();
dataSetSchema.setDataSet(dataSet);
dataSetSchema.setDatabaseType(resp.getDatabaseType());
Set<SchemaElement> metrics = getMetrics(resp);
dataSetSchema.getMetrics().addAll(metrics);