mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(chat) Determine whether the WITH statement is supported and send an explicit message in the prompt to the LLM.
This commit is contained in:
@@ -16,6 +16,7 @@ import java.util.stream.Collectors;
|
|||||||
public class DataSetSchema implements Serializable {
|
public class DataSetSchema implements Serializable {
|
||||||
|
|
||||||
private String databaseType;
|
private String databaseType;
|
||||||
|
private String databaseVersion;
|
||||||
private SchemaElement dataSet;
|
private SchemaElement dataSet;
|
||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.headless.api.pojo.response;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
@@ -14,6 +13,7 @@ import java.util.List;
|
|||||||
public class DataSetSchemaResp extends DataSetResp {
|
public class DataSetSchemaResp extends DataSetResp {
|
||||||
|
|
||||||
private String databaseType;
|
private String databaseType;
|
||||||
|
private String databaseVersion;
|
||||||
private List<MetricSchemaResp> metrics = Lists.newArrayList();
|
private List<MetricSchemaResp> metrics = Lists.newArrayList();
|
||||||
private List<DimSchemaResp> dimensions = Lists.newArrayList();
|
private List<DimSchemaResp> dimensions = Lists.newArrayList();
|
||||||
private List<ModelResp> modelResps = Lists.newArrayList();
|
private List<ModelResp> modelResps = Lists.newArrayList();
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.Pair;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.*;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
@@ -17,11 +14,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
|
||||||
@@ -56,7 +49,9 @@ public class LLMRequestService {
|
|||||||
LLMReq llmReq = new LLMReq();
|
LLMReq llmReq = new LLMReq();
|
||||||
llmReq.setQueryText(queryText);
|
llmReq.setQueryText(queryText);
|
||||||
llmReq.setSchema(llmSchema);
|
llmReq.setSchema(llmSchema);
|
||||||
llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId));
|
Pair<String, String> databaseInfo = getDatabaseType(queryCtx, dataSetId);
|
||||||
|
llmSchema.setDatabaseType(databaseInfo.first);
|
||||||
|
llmSchema.setDatabaseVersion(databaseInfo.second);
|
||||||
llmSchema.setDataSetId(dataSetId);
|
llmSchema.setDataSetId(dataSetId);
|
||||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||||
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||||
@@ -171,13 +166,14 @@ public class LLMRequestService {
|
|||||||
return dataSetSchema.getPrimaryKey();
|
return dataSetSchema.getPrimaryKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
protected Pair<String, String> getDatabaseType(@NotNull ChatQueryContext queryCtx,
|
||||||
|
Long dataSetId) {
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
|
||||||
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
|
||||||
return dataSetSchema.getDatabaseType();
|
return new Pair(dataSetSchema.getDatabaseType(), dataSetSchema.getDatabaseVersion());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -15,10 +17,9 @@ import org.springframework.util.CollectionUtils;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -67,6 +68,11 @@ public class PromptHelper {
|
|||||||
sideInfos.add(String.format("PriorKnowledge=[%s]", llmReq.getPriorExts()));
|
sideInfos.add(String.format("PriorKnowledge=[%s]", llmReq.getPriorExts()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LLMReq.LLMSchema schema = llmReq.getSchema();
|
||||||
|
if (!isSupportWith(schema.getDatabaseType(), schema.getDatabaseVersion())) {
|
||||||
|
sideInfos.add("[Database does not support with statement]");
|
||||||
|
}
|
||||||
|
|
||||||
String termStr = buildTermStr(llmReq);
|
String termStr = buildTermStr(llmReq);
|
||||||
if (StringUtils.isNotEmpty(termStr)) {
|
if (StringUtils.isNotEmpty(termStr)) {
|
||||||
sideInfos.add(String.format("DomainTerms=[%s]", termStr));
|
sideInfos.add(String.format("DomainTerms=[%s]", termStr));
|
||||||
@@ -152,12 +158,17 @@ public class PromptHelper {
|
|||||||
if (llmReq.getSchema().getDatabaseType() != null) {
|
if (llmReq.getSchema().getDatabaseType() != null) {
|
||||||
databaseTypeStr = llmReq.getSchema().getDatabaseType();
|
databaseTypeStr = llmReq.getSchema().getDatabaseType();
|
||||||
}
|
}
|
||||||
|
String databaseVersionStr = "";
|
||||||
|
if (llmReq.getSchema().getDatabaseVersion() != null) {
|
||||||
|
databaseVersionStr = llmReq.getSchema().getDatabaseVersion();
|
||||||
|
}
|
||||||
|
|
||||||
String template =
|
String template =
|
||||||
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
"DatabaseType=[%s], DatabaseVersion=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
|
||||||
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||||
return String.format(template, databaseTypeStr, tableStr, partitionTimeStr, primaryKeyStr,
|
return String.format(template, databaseTypeStr, databaseVersionStr, tableStr,
|
||||||
String.join(",", metrics), String.join(",", dimensions), String.join(",", values));
|
partitionTimeStr, primaryKeyStr, String.join(",", metrics),
|
||||||
|
String.join(",", dimensions), String.join(",", values));
|
||||||
}
|
}
|
||||||
|
|
||||||
private String buildTermStr(LLMReq llmReq) {
|
private String buildTermStr(LLMReq llmReq) {
|
||||||
@@ -176,4 +187,16 @@ public class PromptHelper {
|
|||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean isSupportWith(String type, String version) {
|
||||||
|
if (type.equalsIgnoreCase(EngineType.MYSQL.getName()) && Objects.nonNull(version)
|
||||||
|
&& StringUtil.compareVersion(version, "8.0") < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (type.equalsIgnoreCase(EngineType.CLICKHOUSE.getName()) && Objects.nonNull(version)
|
||||||
|
&& StringUtil.compareVersion(version, "20.4") < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ public class LLMReq {
|
|||||||
@Data
|
@Data
|
||||||
public static class LLMSchema {
|
public static class LLMSchema {
|
||||||
private String databaseType;
|
private String databaseType;
|
||||||
|
private String databaseVersion;
|
||||||
private Long dataSetId;
|
private Long dataSetId;
|
||||||
private String dataSetName;
|
private String dataSetName;
|
||||||
private List<SchemaElement> metrics;
|
private List<SchemaElement> metrics;
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ public class SchemaServiceImpl implements SchemaService {
|
|||||||
DatabaseResp databaseResp = databaseService
|
DatabaseResp databaseResp = databaseService
|
||||||
.getDatabase(dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
|
.getDatabase(dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
|
||||||
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
|
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
|
||||||
|
dataSetSchemaResp.setDatabaseVersion(databaseResp.getVersion());
|
||||||
}
|
}
|
||||||
dataSetSchemaResps.add(dataSetSchemaResp);
|
dataSetSchemaResps.add(dataSetSchemaResp);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ public class DataSetSchemaBuilder {
|
|||||||
.bizName(resp.getBizName()).type(SchemaElementType.DATASET).build();
|
.bizName(resp.getBizName()).type(SchemaElementType.DATASET).build();
|
||||||
dataSetSchema.setDataSet(dataSet);
|
dataSetSchema.setDataSet(dataSet);
|
||||||
dataSetSchema.setDatabaseType(resp.getDatabaseType());
|
dataSetSchema.setDatabaseType(resp.getDatabaseType());
|
||||||
|
dataSetSchema.setDatabaseVersion(resp.getDatabaseVersion());
|
||||||
|
|
||||||
Set<SchemaElement> metrics = getMetrics(resp);
|
Set<SchemaElement> metrics = getMetrics(resp);
|
||||||
dataSetSchema.getMetrics().addAll(metrics);
|
dataSetSchema.getMetrics().addAll(metrics);
|
||||||
|
|||||||
@@ -124,12 +124,13 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
|||||||
DatabaseReq databaseReq = new DatabaseReq();
|
DatabaseReq databaseReq = new DatabaseReq();
|
||||||
databaseReq.setName("S2数据库DEMO");
|
databaseReq.setName("S2数据库DEMO");
|
||||||
databaseReq.setDescription("样例数据库实例仅用于体验");
|
databaseReq.setDescription("样例数据库实例仅用于体验");
|
||||||
databaseReq.setType(DataType.H2.getFeature());
|
databaseReq.setType(DataType.H2.toString());
|
||||||
if ("org.postgresql.Driver".equals(driverClassName)) {
|
if ("org.postgresql.Driver".equals(driverClassName)) {
|
||||||
databaseReq.setType(DataType.POSTGRESQL.getFeature());
|
databaseReq.setType(DataType.POSTGRESQL.toString());
|
||||||
} else if ("com.mysql.cj.jdbc.Driver".equals(driverClassName)
|
} else if ("com.mysql.cj.jdbc.Driver".equals(driverClassName)
|
||||||
|| "com.mysql.jdbc.Driver".equals(driverClassName)) {
|
|| "com.mysql.jdbc.Driver".equals(driverClassName)) {
|
||||||
databaseReq.setType(DataType.MYSQL.getFeature());
|
databaseReq.setType(DataType.MYSQL.toString());
|
||||||
|
databaseReq.setVersion("5.7");
|
||||||
}
|
}
|
||||||
databaseReq.setUrl(url);
|
databaseReq.setUrl(url);
|
||||||
databaseReq.setUsername(dataSourceProperties.getUsername());
|
databaseReq.setUsername(dataSourceProperties.getUsername());
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
spring:
|
spring:
|
||||||
datasource:
|
datasource:
|
||||||
|
driver-class-name: org.postgresql.Driver
|
||||||
url: jdbc:postgresql://${DB_HOST}:${DB_PORT:5432}/${DB_NAME}?stringtype=unspecified
|
url: jdbc:postgresql://${DB_HOST}:${DB_PORT:5432}/${DB_NAME}?stringtype=unspecified
|
||||||
username: ${DB_USERNAME}
|
username: ${DB_USERNAME}
|
||||||
password: ${DB_PASSWORD}
|
password: ${DB_PASSWORD}
|
||||||
driver-class-name: org.postgresql.Driver
|
|
||||||
sql:
|
sql:
|
||||||
init:
|
init:
|
||||||
enabled: false
|
|
||||||
mode: always
|
mode: always
|
||||||
username: ${DB_USERNAME}
|
username: ${DB_USERNAME}
|
||||||
password: ${DB_PASSWORD}
|
password: ${DB_PASSWORD}
|
||||||
|
|||||||
Reference in New Issue
Block a user