(improvement)(Headless) Refactor the SemanticModeller to apply the rule-based modeller first and then the LLM, and automatically infer field types in the rule-based method. (#1900)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-11-11 00:10:58 +08:00
committed by GitHub
parent ea6a9ebc5f
commit 87729956e8
12 changed files with 101 additions and 23 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@@ -14,4 +15,6 @@ public class DBColumn {
private String dataType;
private String comment;
private FieldType fieldType;
}

View File

@@ -1,5 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums;
/**
 * Semantic role of a database column.
 *
 * <p>{@code measure}, {@code time} and {@code dimension} are the values that
 * {@code classifyColumnType} in the DB adaptors derives from a column's type name; the LLM
 * prompt in {@code LLMSemanticModeller} additionally describes {@code primary_key},
 * {@code foreign_key} and {@code partition_time}.
 */
public enum FieldType {
primary_key, foreign_key, data_time, dimension, measure;
primary_key, foreign_key, partition_time, time, dimension, measure;
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
@@ -71,7 +72,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS");
dbColumns.add(new DBColumn(columnName, dataType, remarks));
FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
return dbColumns;
}
@@ -82,4 +84,25 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
return connection.getMetaData();
}
/**
 * Infers a semantic {@link FieldType} from a column's type name.
 *
 * <p>The name is normalised before matching: surrounding whitespace is removed, it is
 * upper-cased with {@link java.util.Locale#ROOT} (so matching does not depend on the JVM's
 * default locale), a parenthesised length/precision such as {@code (10,2)} is dropped, and a
 * trailing {@code UNSIGNED} modifier is ignored.
 *
 * @param typeName the column type name (the {@code TYPE_NAME} metadata value); may be
 *        {@code null}
 * @return {@link FieldType#measure} for numeric types, {@link FieldType#time} for date/time
 *         types, and {@link FieldType#dimension} for everything else, including a
 *         {@code null} or unrecognised name
 */
protected static FieldType classifyColumnType(String typeName) {
    if (typeName == null) {
        // No type information available: fall back to the same default as unknown types.
        return FieldType.dimension;
    }
    String normalized = typeName.trim().toUpperCase(java.util.Locale.ROOT);
    // "DECIMAL(10,2)" / "VARCHAR(255)" -> "DECIMAL" / "VARCHAR"
    int parenIndex = normalized.indexOf('(');
    if (parenIndex >= 0) {
        normalized = normalized.substring(0, parenIndex).trim();
    }
    // "INT UNSIGNED" -> "INT"
    final String unsignedSuffix = " UNSIGNED";
    if (normalized.endsWith(unsignedSuffix)) {
        normalized = normalized.substring(0, normalized.length() - unsignedSuffix.length()).trim();
    }
    switch (normalized) {
        case "INT":
        case "INTEGER":
        case "BIGINT":
        case "SMALLINT":
        case "TINYINT":
        case "FLOAT":
        case "REAL":
        case "DOUBLE":
        case "DOUBLE PRECISION":
        case "DECIMAL":
        case "NUMERIC":
            return FieldType.measure;
        case "DATE":
        case "TIME":
        case "DATETIME":
        case "TIMESTAMP":
            return FieldType.time;
        default:
            return FieldType.dimension;
    }
}
}

View File

@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
@@ -54,7 +55,8 @@ public class H2Adaptor extends BaseDbAdaptor {
String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS");
dbColumns.add(new DBColumn(columnName, dataType, remarks));
FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
return dbColumns;
}

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.StringValue;
@@ -105,8 +106,41 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS");
dbColumns.add(new DBColumn(columnName, dataType, remarks));
FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
return dbColumns;
}
/**
 * Infers a semantic {@link FieldType} from a PostgreSQL column type name.
 *
 * <p>Both the SQL-standard spellings ({@code INTEGER}, {@code DOUBLE PRECISION}, ...) and
 * PostgreSQL's internal type names ({@code int2}, {@code int4}, {@code int8}, {@code float4},
 * {@code float8}, {@code timetz}) are recognised. The name is trimmed, upper-cased with
 * {@link java.util.Locale#ROOT} (so matching does not depend on the JVM's default locale) and
 * stripped of any parenthesised length/precision before matching.
 *
 * @param typeName the column type name (the {@code TYPE_NAME} metadata value); may be
 *        {@code null}
 * @return {@link FieldType#measure} for numeric types, {@link FieldType#time} for
 *         date/time/interval types, and {@link FieldType#dimension} for everything else
 *         (character types, UUID, unrecognised names, {@code null})
 */
protected static FieldType classifyColumnType(String typeName) {
    if (typeName == null) {
        // No type information available: fall back to the same default as unknown types.
        return FieldType.dimension;
    }
    String normalized = typeName.trim().toUpperCase(java.util.Locale.ROOT);
    // "NUMERIC(10,2)" / "VARCHAR(255)" -> "NUMERIC" / "VARCHAR"
    int parenIndex = normalized.indexOf('(');
    if (parenIndex >= 0) {
        normalized = normalized.substring(0, parenIndex).trim();
    }
    switch (normalized) {
        case "INT":
        case "INTEGER":
        case "INT2":
        case "INT4":
        case "INT8":
        case "BIGINT":
        case "SMALLINT":
        case "SERIAL":
        case "BIGSERIAL":
        case "SMALLSERIAL":
        case "REAL":
        case "FLOAT4":
        case "FLOAT8":
        case "DOUBLE PRECISION":
        case "NUMERIC":
        case "DECIMAL":
            return FieldType.measure;
        case "DATE":
        case "TIME":
        case "TIMETZ":
        case "TIMESTAMP":
        case "TIMESTAMPTZ":
        case "INTERVAL":
            return FieldType.time;
        // VARCHAR, CHAR, TEXT, CHARACTER VARYING, CHARACTER, UUID and any other type
        // are treated as dimensions.
        default:
            return FieldType.dimension;
    }
}
}

View File

@@ -45,7 +45,7 @@ public class LLMSemanticModeller implements SemanticModeller {
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
+ "\n primary_key: This is a unique identifier for a record row in a database."
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
+ "\n data_time: This represents the time when data is generated in the data warehouse."
+ "\n partition_time: This represents the time when data is generated in the data warehouse."
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
@@ -66,22 +66,23 @@ public class LLMSemanticModeller implements SemanticModeller {
}
/**
 * Asks the chat model to generate a model schema for {@code dbSchema}.
 *
 * <p>Returns without calling the model when {@code modelBuildReq.isBuildByLLM()} is false, or
 * when the chat app registered under {@code APP_KEY} is absent or not enabled. Otherwise a
 * prompt is generated from {@code dbSchema} and the schemas returned by
 * {@code getOtherDbSchema}, sent through a {@code ModelSchemaExtractor}, and the inputs and
 * result are logged.
 */
@Override
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
ModelBuildReq modelBuildReq) {
// LLM-based building is opt-in per request.
if (!modelBuildReq.isBuildByLLM()) {
return;
}
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
// Nothing to do when the chat app is missing or disabled.
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
return null;
return;
}
List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas);
ModelSchemaExtractor extractor =
AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq));
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
ModelSchema modelSchema =
extractor.generateModelSchema(prompt.toUserMessage().singleText());
// NOTE(review): in the void variant this assignment only rebinds the local parameter
// reference; the ModelSchema object supplied by the caller is not modified, so the
// generated schema is only logged below and is not visible to the caller. Confirm whether
// the result should be copied into the passed-in instance instead.
modelSchema = extractor.generateModelSchema(prompt.toUserMessage().singleText());
log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}",
JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema),
JsonUtil.toString(modelSchema));
return modelSchema;
}
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
@@ -113,7 +114,7 @@ public class LLMSemanticModeller implements SemanticModeller {
Environment environment = ContextUtils.getBean(Environment.class);
String enableExemplarLoading =
environment.getProperty("s2.model.building.exemplars.enabled");
if (Boolean.TRUE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
if (Boolean.FALSE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
log.info("Not enable load model-building exemplars");
return "";
}

View File

@@ -14,13 +14,11 @@ import java.util.stream.Collectors;
public class RuleSemanticModeller implements SemanticModeller {
/**
 * Rule-based model building: converts every {@code DBColumn} of {@code dbSchema} into a
 * {@code ColumnSchema} via {@code convert} and stores the resulting list on
 * {@code modelSchema} with {@code setColumnSchemas}.
 *
 * <p>{@code dbSchemas} and {@code modelBuildReq} are not read by the lines shown here.
 */
@Override
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
ModelBuildReq modelBuildReq) {
ModelSchema modelSchema = new ModelSchema();
// One ColumnSchema per database column, in the order returned by getDbColumns().
List<ColumnSchema> columnSchemas =
dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList());
modelSchema.setColumnSchemas(columnSchemas);
return modelSchema;
}
private ColumnSchema convert(DBColumn dbColumn) {
@@ -29,6 +27,7 @@ public class RuleSemanticModeller implements SemanticModeller {
columnSchema.setColumnName(dbColumn.getColumnName());
columnSchema.setComment(dbColumn.getComment());
columnSchema.setDataType(dbColumn.getDataType());
columnSchema.setFiledType(dbColumn.getFieldType());
return columnSchema;
}

View File

@@ -9,6 +9,7 @@ import java.util.List;
/**
 * A strategy that contributes to the {@code ModelSchema} built for a database table.
 *
 * <p>The {@code void build(...)} variant receives the {@code ModelSchema} to work on and
 * returns nothing; {@code RuleSemanticModeller}, for example, fills it through
 * {@code setColumnSchemas}. {@code dbSchema} describes the table being modelled;
 * {@code otherDbSchema} presumably carries the schemas of the other tables in the same
 * request (TODO confirm against callers).
 */
public interface SemanticModeller {
ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelBuildReq modelBuildReq);
void build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelSchema modelSchema,
ModelBuildReq modelBuildReq);
}

View File

@@ -222,8 +222,11 @@ public class ModelServiceImpl implements ModelService {
/**
 * Builds the {@code ModelSchema} for {@code curSchema} and stores it in
 * {@code modelSchemaMap} under the key {@code curSchema.getTable()}.
 *
 * <p>A new, empty {@code ModelSchema} is passed to every {@code SemanticModeller} returned by
 * {@code CoreComponentFactory.getSemanticModellers()}, in the order of that list.
 */
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
Map<String, ModelSchema> modelSchemaMap) {
SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller();
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq);
ModelSchema modelSchema = new ModelSchema();
List<SemanticModeller> semanticModellers = CoreComponentFactory.getSemanticModellers();
// Each modeller receives the same ModelSchema instance, one after another.
for (SemanticModeller semanticModeller : semanticModellers) {
semanticModeller.build(curSchema, dbSchemas, modelSchema, modelBuildReq);
}
modelSchemaMap.put(curSchema.getTable(), modelSchema);
}

View File

@@ -4,15 +4,26 @@ import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
/**
 * QueryConverter QueryOptimizer QueryExecutor object factory.
 *
 * <p>The part shown here caches the {@code SemanticModeller} implementations in a static list
 * that is filled on first access.
 */
@Slf4j
public class CoreComponentFactory extends ComponentFactory {
private static SemanticModeller semanticModeller;
// Cache of modellers; empty until getSemanticModellers() is first called.
private static List<SemanticModeller> semanticModellers = new ArrayList<>();
public static SemanticModeller getSemanticModeller() {
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
// Returns the cached modeller list, initialising it first whenever it is empty.
// NOTE(review): there is no synchronisation here, so concurrent first calls may each run
// the initialisation; the internal mutable list is also returned directly - confirm both
// are acceptable.
public static List<SemanticModeller> getSemanticModellers() {
if (semanticModellers.isEmpty()) {
initSemanticModellers();
}
return semanticModellers;
}
// Delegates to init(Class, List); presumably that helper adds the discovered
// SemanticModeller implementations to the list - verify where init is defined.
private static void initSemanticModellers() {
init(SemanticModeller.class, semanticModellers);
}
}