mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat)Determine if with statement is supported and send explicitly message in the prompt to the LLM.
This commit is contained in:
@@ -16,6 +16,7 @@ import java.util.stream.Collectors;
|
||||
public class DataSetSchema implements Serializable {
|
||||
|
||||
private String databaseType;
|
||||
private String databaseVersion;
|
||||
private SchemaElement dataSet;
|
||||
private Set<SchemaElement> metrics = new HashSet<>();
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.response;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -14,6 +13,7 @@ import java.util.List;
|
||||
public class DataSetSchemaResp extends DataSetResp {
|
||||
|
||||
private String databaseType;
|
||||
private String databaseVersion;
|
||||
private List<MetricSchemaResp> metrics = Lists.newArrayList();
|
||||
private List<DimSchemaResp> dimensions = Lists.newArrayList();
|
||||
private List<ModelResp> modelResps = Lists.newArrayList();
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Pair;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -17,11 +14,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
|
||||
@@ -56,7 +49,9 @@ public class LLMRequestService {
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
llmReq.setSchema(llmSchema);
|
||||
llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId));
|
||||
Pair<String, String> databaseInfo = getDatabaseType(queryCtx, dataSetId);
|
||||
llmSchema.setDatabaseType(databaseInfo.first);
|
||||
llmSchema.setDatabaseVersion(databaseInfo.second);
|
||||
llmSchema.setDataSetId(dataSetId);
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||
@@ -171,13 +166,14 @@ public class LLMRequestService {
|
||||
return dataSetSchema.getPrimaryKey();
|
||||
}
|
||||
|
||||
protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected Pair<String, String> getDatabaseType(@NotNull ChatQueryContext queryCtx,
|
||||
Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return null;
|
||||
}
|
||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||
return dataSetSchema.getDatabaseType();
|
||||
return new Pair(dataSetSchema.getDatabaseType(), dataSetSchema.getDatabaseVersion());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -15,10 +17,9 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -67,6 +68,11 @@ public class PromptHelper {
|
||||
sideInfos.add(String.format("PriorKnowledge=[%s]", llmReq.getPriorExts()));
|
||||
}
|
||||
|
||||
LLMReq.LLMSchema schema = llmReq.getSchema();
|
||||
if (!isSupportWith(schema.getDatabaseType(), schema.getDatabaseVersion())) {
|
||||
sideInfos.add("[Database does not support with statement]");
|
||||
}
|
||||
|
||||
String termStr = buildTermStr(llmReq);
|
||||
if (StringUtils.isNotEmpty(termStr)) {
|
||||
sideInfos.add(String.format("DomainTerms=[%s]", termStr));
|
||||
@@ -152,12 +158,17 @@ public class PromptHelper {
|
||||
if (llmReq.getSchema().getDatabaseType() != null) {
|
||||
databaseTypeStr = llmReq.getSchema().getDatabaseType();
|
||||
}
|
||||
String databaseVersionStr = "";
|
||||
if (llmReq.getSchema().getDatabaseVersion() != null) {
|
||||
databaseVersionStr = llmReq.getSchema().getDatabaseVersion();
|
||||
}
|
||||
|
||||
String template =
|
||||
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||
"DatabaseType=[%s], DatabaseVersion=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||
return String.format(template, databaseTypeStr, tableStr, partitionTimeStr, primaryKeyStr,
|
||||
String.join(",", metrics), String.join(",", dimensions), String.join(",", values));
|
||||
return String.format(template, databaseTypeStr, databaseVersionStr, tableStr,
|
||||
partitionTimeStr, primaryKeyStr, String.join(",", metrics),
|
||||
String.join(",", dimensions), String.join(",", values));
|
||||
}
|
||||
|
||||
private String buildTermStr(LLMReq llmReq) {
|
||||
@@ -176,4 +187,16 @@ public class PromptHelper {
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
public static boolean isSupportWith(String type, String version) {
|
||||
if (type.equalsIgnoreCase(EngineType.MYSQL.getName()) && Objects.nonNull(version)
|
||||
&& StringUtil.compareVersion(version, "8.0") < 0) {
|
||||
return false;
|
||||
}
|
||||
if (type.equalsIgnoreCase(EngineType.CLICKHOUSE.getName()) && Objects.nonNull(version)
|
||||
&& StringUtil.compareVersion(version, "20.4") < 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ public class LLMReq {
|
||||
@Data
|
||||
public static class LLMSchema {
|
||||
private String databaseType;
|
||||
private String databaseVersion;
|
||||
private Long dataSetId;
|
||||
private String dataSetName;
|
||||
private List<SchemaElement> metrics;
|
||||
|
||||
@@ -197,6 +197,7 @@ public class SchemaServiceImpl implements SchemaService {
|
||||
DatabaseResp databaseResp = databaseService
|
||||
.getDatabase(dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
|
||||
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
|
||||
dataSetSchemaResp.setDatabaseVersion(databaseResp.getVersion());
|
||||
}
|
||||
dataSetSchemaResps.add(dataSetSchemaResp);
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ public class DataSetSchemaBuilder {
|
||||
.bizName(resp.getBizName()).type(SchemaElementType.DATASET).build();
|
||||
dataSetSchema.setDataSet(dataSet);
|
||||
dataSetSchema.setDatabaseType(resp.getDatabaseType());
|
||||
dataSetSchema.setDatabaseVersion(resp.getDatabaseVersion());
|
||||
|
||||
Set<SchemaElement> metrics = getMetrics(resp);
|
||||
dataSetSchema.getMetrics().addAll(metrics);
|
||||
|
||||
@@ -124,12 +124,13 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
DatabaseReq databaseReq = new DatabaseReq();
|
||||
databaseReq.setName("S2数据库DEMO");
|
||||
databaseReq.setDescription("样例数据库实例仅用于体验");
|
||||
databaseReq.setType(DataType.H2.getFeature());
|
||||
databaseReq.setType(DataType.H2.toString());
|
||||
if ("org.postgresql.Driver".equals(driverClassName)) {
|
||||
databaseReq.setType(DataType.POSTGRESQL.getFeature());
|
||||
databaseReq.setType(DataType.POSTGRESQL.toString());
|
||||
} else if ("com.mysql.cj.jdbc.Driver".equals(driverClassName)
|
||||
|| "com.mysql.jdbc.Driver".equals(driverClassName)) {
|
||||
databaseReq.setType(DataType.MYSQL.getFeature());
|
||||
databaseReq.setType(DataType.MYSQL.toString());
|
||||
databaseReq.setVersion("5.7");
|
||||
}
|
||||
databaseReq.setUrl(url);
|
||||
databaseReq.setUsername(dataSourceProperties.getUsername());
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
spring:
|
||||
datasource:
|
||||
driver-class-name: org.postgresql.Driver
|
||||
url: jdbc:postgresql://${DB_HOST}:${DB_PORT:5432}/${DB_NAME}?stringtype=unspecified
|
||||
username: ${DB_USERNAME}
|
||||
password: ${DB_PASSWORD}
|
||||
driver-class-name: org.postgresql.Driver
|
||||
|
||||
sql:
|
||||
init:
|
||||
enabled: false
|
||||
mode: always
|
||||
username: ${DB_USERNAME}
|
||||
password: ${DB_PASSWORD}
|
||||
|
||||
Reference in New Issue
Block a user