(improvement)(chat) Determine whether the WITH statement is supported by the database and, when it is not, explicitly say so in the prompt sent to the LLM.

This commit is contained in:
jerryjzhang
2025-02-17 21:21:43 +08:00
parent f31db98aba
commit 32793ecf69
9 changed files with 49 additions and 25 deletions

View File

@@ -16,6 +16,7 @@ import java.util.stream.Collectors;
public class DataSetSchema implements Serializable {
private String databaseType;
private String databaseVersion;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.Identify;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@@ -14,6 +13,7 @@ import java.util.List;
public class DataSetSchemaResp extends DataSetResp {
private String databaseType;
private String databaseVersion;
private List<MetricSchemaResp> metrics = Lists.newArrayList();
private List<DimSchemaResp> dimensions = Lists.newArrayList();
private List<ModelResp> modelResps = Lists.newArrayList();

View File

@@ -1,11 +1,8 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.Pair;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
@@ -17,11 +14,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
@@ -56,7 +49,9 @@ public class LLMRequestService {
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
llmReq.setSchema(llmSchema);
llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId));
Pair<String, String> databaseInfo = getDatabaseType(queryCtx, dataSetId);
llmSchema.setDatabaseType(databaseInfo.first);
llmSchema.setDatabaseVersion(databaseInfo.second);
llmSchema.setDataSetId(dataSetId);
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
@@ -171,13 +166,14 @@ public class LLMRequestService {
return dataSetSchema.getPrimaryKey();
}
protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
protected Pair<String, String> getDatabaseType(@NotNull ChatQueryContext queryCtx,
Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
return null;
}
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
return dataSetSchema.getDatabaseType();
return new Pair(dataSetSchema.getDatabaseType(), dataSetSchema.getDatabaseVersion());
}
}

View File

@@ -3,7 +3,9 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import lombok.extern.slf4j.Slf4j;
@@ -15,10 +17,9 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
@Component
@Slf4j
@@ -67,6 +68,11 @@ public class PromptHelper {
sideInfos.add(String.format("PriorKnowledge=[%s]", llmReq.getPriorExts()));
}
LLMReq.LLMSchema schema = llmReq.getSchema();
if (!isSupportWith(schema.getDatabaseType(), schema.getDatabaseVersion())) {
sideInfos.add("[Database does not support with statement]");
}
String termStr = buildTermStr(llmReq);
if (StringUtils.isNotEmpty(termStr)) {
sideInfos.add(String.format("DomainTerms=[%s]", termStr));
@@ -152,12 +158,17 @@ public class PromptHelper {
if (llmReq.getSchema().getDatabaseType() != null) {
databaseTypeStr = llmReq.getSchema().getDatabaseType();
}
String databaseVersionStr = "";
if (llmReq.getSchema().getDatabaseVersion() != null) {
databaseVersionStr = llmReq.getSchema().getDatabaseVersion();
}
String template =
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
"DatabaseType=[%s], DatabaseVersion=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
return String.format(template, databaseTypeStr, tableStr, partitionTimeStr, primaryKeyStr,
String.join(",", metrics), String.join(",", dimensions), String.join(",", values));
return String.format(template, databaseTypeStr, databaseVersionStr, tableStr,
partitionTimeStr, primaryKeyStr, String.join(",", metrics),
String.join(",", dimensions), String.join(",", values));
}
private String buildTermStr(LLMReq llmReq) {
@@ -176,4 +187,16 @@ public class PromptHelper {
return ret;
}
/**
 * Tells whether the given database engine/version is considered to support the SQL
 * {@code WITH} statement.
 *
 * <p>Only two combinations are reported as unsupported: MySQL with a version that compares
 * below {@code "8.0"}, and ClickHouse with a version that compares below {@code "20.4"}.
 * Every other input, including an unknown ({@code null}) type or version, is treated as
 * supported so that no restriction is sent to the LLM without evidence.
 *
 * @param type database engine name, matched case-insensitively against {@link EngineType}
 *        names; may be {@code null} when the data set has no database type
 * @param version database version string; may be {@code null} when unknown
 * @return {@code false} only for the old MySQL / ClickHouse versions described above,
 *         {@code true} otherwise
 */
public static boolean isSupportWith(String type, String version) {
    // The schema's database type/version are optional (callers null-check them elsewhere),
    // so guard here instead of failing with a NullPointerException while building the prompt.
    if (type == null || version == null) {
        return true;
    }
    // NOTE(review): assumes StringUtil.compareVersion returns a negative value when the
    // first version is lower than the second - confirm against its implementation.
    if (type.equalsIgnoreCase(EngineType.MYSQL.getName())
            && StringUtil.compareVersion(version, "8.0") < 0) {
        return false;
    }
    if (type.equalsIgnoreCase(EngineType.CLICKHOUSE.getName())
            && StringUtil.compareVersion(version, "20.4") < 0) {
        return false;
    }
    return true;
}
}

View File

@@ -38,6 +38,7 @@ public class LLMReq {
@Data
public static class LLMSchema {
private String databaseType;
private String databaseVersion;
private Long dataSetId;
private String dataSetName;
private List<SchemaElement> metrics;

View File

@@ -197,6 +197,7 @@ public class SchemaServiceImpl implements SchemaService {
DatabaseResp databaseResp = databaseService
.getDatabase(dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
dataSetSchemaResp.setDatabaseVersion(databaseResp.getVersion());
}
dataSetSchemaResps.add(dataSetSchemaResp);
}

View File

@@ -24,6 +24,7 @@ public class DataSetSchemaBuilder {
.bizName(resp.getBizName()).type(SchemaElementType.DATASET).build();
dataSetSchema.setDataSet(dataSet);
dataSetSchema.setDatabaseType(resp.getDatabaseType());
dataSetSchema.setDatabaseVersion(resp.getDatabaseVersion());
Set<SchemaElement> metrics = getMetrics(resp);
dataSetSchema.getMetrics().addAll(metrics);

View File

@@ -124,12 +124,13 @@ public abstract class S2BaseDemo implements CommandLineRunner {
DatabaseReq databaseReq = new DatabaseReq();
databaseReq.setName("S2数据库DEMO");
databaseReq.setDescription("样例数据库实例仅用于体验");
databaseReq.setType(DataType.H2.getFeature());
databaseReq.setType(DataType.H2.toString());
if ("org.postgresql.Driver".equals(driverClassName)) {
databaseReq.setType(DataType.POSTGRESQL.getFeature());
databaseReq.setType(DataType.POSTGRESQL.toString());
} else if ("com.mysql.cj.jdbc.Driver".equals(driverClassName)
|| "com.mysql.jdbc.Driver".equals(driverClassName)) {
databaseReq.setType(DataType.MYSQL.getFeature());
databaseReq.setType(DataType.MYSQL.toString());
databaseReq.setVersion("5.7");
}
databaseReq.setUrl(url);
databaseReq.setUsername(dataSourceProperties.getUsername());

View File

@@ -1,12 +1,12 @@
spring:
datasource:
driver-class-name: org.postgresql.Driver
url: jdbc:postgresql://${DB_HOST}:${DB_PORT:5432}/${DB_NAME}?stringtype=unspecified
username: ${DB_USERNAME}
password: ${DB_PASSWORD}
driver-class-name: org.postgresql.Driver
sql:
init:
enabled: false
mode: always
username: ${DB_USERNAME}
password: ${DB_PASSWORD}