From 32793ecf6931c4f476b4c7cf6d95d382b5c20d5a Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Mon, 17 Feb 2025 21:21:43 +0800 Subject: [PATCH] (improvement)(chat)Determine whether the WITH statement is supported and send an explicit message in the prompt to the LLM. --- .../headless/api/pojo/DataSetSchema.java | 1 + .../api/pojo/response/DataSetSchemaResp.java | 2 +- .../chat/parser/llm/LLMRequestService.java | 22 +++++------- .../chat/parser/llm/PromptHelper.java | 35 +++++++++++++++---- .../headless/chat/query/llm/s2sql/LLMReq.java | 1 + .../service/impl/SchemaServiceImpl.java | 1 + .../server/utils/DataSetSchemaBuilder.java | 1 + .../tencent/supersonic/demo/S2BaseDemo.java | 7 ++-- .../main/resources/application-docker.yaml | 4 +-- 9 files changed, 49 insertions(+), 25 deletions(-) diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index b0a0f76e3..f33b8c41a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -16,6 +16,7 @@ import java.util.stream.Collectors; public class DataSetSchema implements Serializable { private String databaseType; + private String databaseVersion; private SchemaElement dataSet; private Set metrics = new HashSet<>(); private Set dimensions = new HashSet<>(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetSchemaResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetSchemaResp.java index e1f373224..c9811e737 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetSchemaResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetSchemaResp.java @@ -1,7 +1,6 @@ package 
com.tencent.supersonic.headless.api.pojo.response; import com.google.common.collect.Lists; -import com.tencent.supersonic.headless.api.pojo.Identify; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -14,6 +13,7 @@ import java.util.List; public class DataSetSchemaResp extends DataSetResp { private String databaseType; + private String databaseVersion; private List metrics = Lists.newArrayList(); private List dimensions = Lists.newArrayList(); private List modelResps = Lists.newArrayList(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 3de69f0cc..66ae6ee2b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -1,11 +1,8 @@ package com.tencent.supersonic.headless.chat.parser.llm; +import com.tencent.supersonic.common.pojo.Pair; import com.tencent.supersonic.common.util.DateUtils; -import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; -import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.api.pojo.SchemaElementType; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.*; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; @@ -17,11 +14,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.Collections; -import 
java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*; @@ -56,7 +49,9 @@ public class LLMRequestService { LLMReq llmReq = new LLMReq(); llmReq.setQueryText(queryText); llmReq.setSchema(llmSchema); - llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId)); + Pair databaseInfo = getDatabaseType(queryCtx, dataSetId); + llmSchema.setDatabaseType(databaseInfo.first); + llmSchema.setDatabaseVersion(databaseInfo.second); llmSchema.setDataSetId(dataSetId); llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId)); @@ -171,13 +166,14 @@ public class LLMRequestService { return dataSetSchema.getPrimaryKey(); } - protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) { + protected Pair getDatabaseType(@NotNull ChatQueryContext queryCtx, + Long dataSetId) { SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { return null; } Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); - return dataSetSchema.getDatabaseType(); + return new Pair(dataSetSchema.getDatabaseType(), dataSetSchema.getDatabaseVersion()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index 2677458ac..d85572dbb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -3,7 +3,9 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; import 
com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; +import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.common.service.ExemplarService; +import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import lombok.extern.slf4j.Slf4j; @@ -15,10 +17,9 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER; -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER; +import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*; @Component @Slf4j @@ -67,6 +68,11 @@ public class PromptHelper { sideInfos.add(String.format("PriorKnowledge=[%s]", llmReq.getPriorExts())); } + LLMReq.LLMSchema schema = llmReq.getSchema(); + if (!isSupportWith(schema.getDatabaseType(), schema.getDatabaseVersion())) { + sideInfos.add("[Database does not support with statement]"); + } + String termStr = buildTermStr(llmReq); if (StringUtils.isNotEmpty(termStr)) { sideInfos.add(String.format("DomainTerms=[%s]", termStr)); @@ -152,12 +158,17 @@ public class PromptHelper { if (llmReq.getSchema().getDatabaseType() != null) { databaseTypeStr = llmReq.getSchema().getDatabaseType(); } + String databaseVersionStr = ""; + if (llmReq.getSchema().getDatabaseVersion() != null) { + databaseVersionStr = llmReq.getSchema().getDatabaseVersion(); + } String template = - "DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], " + "DatabaseType=[%s], DatabaseVersion=[%s], Table=[%s], PartitionTimeField=[%s], 
PrimaryKeyField=[%s], " + "Metrics=[%s], Dimensions=[%s], Values=[%s]"; - return String.format(template, databaseTypeStr, tableStr, partitionTimeStr, primaryKeyStr, - String.join(",", metrics), String.join(",", dimensions), String.join(",", values)); + return String.format(template, databaseTypeStr, databaseVersionStr, tableStr, + partitionTimeStr, primaryKeyStr, String.join(",", metrics), + String.join(",", dimensions), String.join(",", values)); } private String buildTermStr(LLMReq llmReq) { @@ -176,4 +187,16 @@ public class PromptHelper { return ret; } + + public static boolean isSupportWith(String type, String version) { + if (type.equalsIgnoreCase(EngineType.MYSQL.getName()) && Objects.nonNull(version) + && StringUtil.compareVersion(version, "8.0") < 0) { + return false; + } + if (type.equalsIgnoreCase(EngineType.CLICKHOUSE.getName()) && Objects.nonNull(version) + && StringUtil.compareVersion(version, "20.4") < 0) { + return false; + } + return true; + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index c3a8c4228..cce5ec983 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -38,6 +38,7 @@ public class LLMReq { @Data public static class LLMSchema { private String databaseType; + private String databaseVersion; private Long dataSetId; private String dataSetName; private List metrics; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java index 361fffe51..e7b24c616 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java +++ 
b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java @@ -197,6 +197,7 @@ public class SchemaServiceImpl implements SchemaService { DatabaseResp databaseResp = databaseService .getDatabase(dataSetSchemaResp.getModelResps().get(0).getDatabaseId()); dataSetSchemaResp.setDatabaseType(databaseResp.getType()); + dataSetSchemaResp.setDatabaseVersion(databaseResp.getVersion()); } dataSetSchemaResps.add(dataSetSchemaResp); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java index eca5c6712..3befbb47d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java @@ -24,6 +24,7 @@ public class DataSetSchemaBuilder { .bizName(resp.getBizName()).type(SchemaElementType.DATASET).build(); dataSetSchema.setDataSet(dataSet); dataSetSchema.setDatabaseType(resp.getDatabaseType()); + dataSetSchema.setDatabaseVersion(resp.getDatabaseVersion()); Set metrics = getMetrics(resp); dataSetSchema.getMetrics().addAll(metrics); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java index 6a568c235..3b88aeed1 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java @@ -124,12 +124,13 @@ public abstract class S2BaseDemo implements CommandLineRunner { DatabaseReq databaseReq = new DatabaseReq(); databaseReq.setName("S2数据库DEMO"); databaseReq.setDescription("样例数据库实例仅用于体验"); - databaseReq.setType(DataType.H2.getFeature()); + databaseReq.setType(DataType.H2.toString()); if 
("org.postgresql.Driver".equals(driverClassName)) { - databaseReq.setType(DataType.POSTGRESQL.getFeature()); + databaseReq.setType(DataType.POSTGRESQL.toString()); } else if ("com.mysql.cj.jdbc.Driver".equals(driverClassName) || "com.mysql.jdbc.Driver".equals(driverClassName)) { - databaseReq.setType(DataType.MYSQL.getFeature()); + databaseReq.setType(DataType.MYSQL.toString()); + databaseReq.setVersion("5.7"); } databaseReq.setUrl(url); databaseReq.setUsername(dataSourceProperties.getUsername()); diff --git a/launchers/standalone/src/main/resources/application-docker.yaml b/launchers/standalone/src/main/resources/application-docker.yaml index df50ed6c5..8ec7476d5 100644 --- a/launchers/standalone/src/main/resources/application-docker.yaml +++ b/launchers/standalone/src/main/resources/application-docker.yaml @@ -1,12 +1,12 @@ spring: datasource: + driver-class-name: org.postgresql.Driver url: jdbc:postgresql://${DB_HOST}:${DB_PORT:5432}/${DB_NAME}?stringtype=unspecified username: ${DB_USERNAME} password: ${DB_PASSWORD} - driver-class-name: org.postgresql.Driver + sql: init: - enabled: false mode: always username: ${DB_USERNAME} password: ${DB_PASSWORD}