mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(Headless) Refactor the SemanticModeller to rule first and then llm, and automatically infer field types in the rule method. (#1900)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -14,4 +15,6 @@ public class DBColumn {
|
||||
private String dataType;
|
||||
|
||||
private String comment;
|
||||
|
||||
private FieldType fieldType;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||
|
||||
/**
 * Semantic role assigned to a table column during model building.
 *
 * <p>As described by the LLM prompt in this change: {@code primary_key} uniquely identifies a
 * record row, {@code foreign_key} derives its value from another table's primary key,
 * {@code partition_time} is the time when data is generated in the data warehouse,
 * {@code dimension} is used for grouping and filtering, and {@code measure} quantifies data and
 * is aggregated. {@code time} is what the rule-based classifier assigns to DATE/TIME/TIMESTAMP
 * style column types.
 */
public enum FieldType {
|
||||
primary_key, foreign_key, data_time, dimension, measure;
|
||||
primary_key, foreign_key, partition_time, time, dimension, measure;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.adaptor.db;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -71,7 +72,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
||||
String columnName = columns.getString("COLUMN_NAME");
|
||||
String dataType = columns.getString("TYPE_NAME");
|
||||
String remarks = columns.getString("REMARKS");
|
||||
dbColumns.add(new DBColumn(columnName, dataType, remarks));
|
||||
FieldType fieldType = classifyColumnType(dataType);
|
||||
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
|
||||
}
|
||||
return dbColumns;
|
||||
}
|
||||
@@ -82,4 +84,25 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
||||
return connection.getMetaData();
|
||||
}
|
||||
|
||||
protected static FieldType classifyColumnType(String typeName) {
|
||||
switch (typeName.toUpperCase()) {
|
||||
case "INT":
|
||||
case "INTEGER":
|
||||
case "BIGINT":
|
||||
case "SMALLINT":
|
||||
case "TINYINT":
|
||||
case "FLOAT":
|
||||
case "DOUBLE":
|
||||
case "DECIMAL":
|
||||
case "NUMERIC":
|
||||
return FieldType.measure;
|
||||
case "DATE":
|
||||
case "TIME":
|
||||
case "TIMESTAMP":
|
||||
return FieldType.time;
|
||||
default:
|
||||
return FieldType.dimension;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -54,7 +55,8 @@ public class H2Adaptor extends BaseDbAdaptor {
|
||||
String columnName = columns.getString("COLUMN_NAME");
|
||||
String dataType = columns.getString("TYPE_NAME");
|
||||
String remarks = columns.getString("REMARKS");
|
||||
dbColumns.add(new DBColumn(columnName, dataType, remarks));
|
||||
FieldType fieldType = classifyColumnType(dataType);
|
||||
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
|
||||
}
|
||||
return dbColumns;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
@@ -105,8 +106,41 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
|
||||
String columnName = columns.getString("COLUMN_NAME");
|
||||
String dataType = columns.getString("TYPE_NAME");
|
||||
String remarks = columns.getString("REMARKS");
|
||||
dbColumns.add(new DBColumn(columnName, dataType, remarks));
|
||||
FieldType fieldType = classifyColumnType(dataType);
|
||||
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
|
||||
}
|
||||
return dbColumns;
|
||||
}
|
||||
|
||||
protected static FieldType classifyColumnType(String typeName) {
|
||||
switch (typeName.toUpperCase()) {
|
||||
case "INT":
|
||||
case "INTEGER":
|
||||
case "BIGINT":
|
||||
case "SMALLINT":
|
||||
case "SERIAL":
|
||||
case "BIGSERIAL":
|
||||
case "SMALLSERIAL":
|
||||
case "REAL":
|
||||
case "DOUBLE PRECISION":
|
||||
case "NUMERIC":
|
||||
case "DECIMAL":
|
||||
return FieldType.measure;
|
||||
case "DATE":
|
||||
case "TIME":
|
||||
case "TIMESTAMP":
|
||||
case "TIMESTAMPTZ":
|
||||
case "INTERVAL":
|
||||
return FieldType.time;
|
||||
case "VARCHAR":
|
||||
case "CHAR":
|
||||
case "TEXT":
|
||||
case "CHARACTER VARYING":
|
||||
case "CHARACTER":
|
||||
case "UUID":
|
||||
default:
|
||||
return FieldType.dimension;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ public class LLMSemanticModeller implements SemanticModeller {
|
||||
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
|
||||
+ "\n primary_key: This is a unique identifier for a record row in a database."
|
||||
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
|
||||
+ "\n data_time: This represents the time when data is generated in the data warehouse."
|
||||
+ "\n partition_time: This represents the time when data is generated in the data warehouse."
|
||||
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
|
||||
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
|
||||
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
|
||||
@@ -66,22 +66,23 @@ public class LLMSemanticModeller implements SemanticModeller {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
|
||||
public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
|
||||
ModelBuildReq modelBuildReq) {
|
||||
if (!modelBuildReq.isBuildByLLM()) {
|
||||
return;
|
||||
}
|
||||
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
|
||||
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
|
||||
return null;
|
||||
return;
|
||||
}
|
||||
List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas);
|
||||
ModelSchemaExtractor extractor =
|
||||
AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq));
|
||||
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
|
||||
ModelSchema modelSchema =
|
||||
extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
||||
modelSchema = extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
||||
log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}",
|
||||
JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema),
|
||||
JsonUtil.toString(modelSchema));
|
||||
return modelSchema;
|
||||
}
|
||||
|
||||
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
|
||||
@@ -113,7 +114,7 @@ public class LLMSemanticModeller implements SemanticModeller {
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String enableExemplarLoading =
|
||||
environment.getProperty("s2.model.building.exemplars.enabled");
|
||||
if (Boolean.TRUE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
|
||||
if (Boolean.FALSE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
|
||||
log.info("Not enable load model-building exemplars");
|
||||
return "";
|
||||
}
|
||||
|
||||
@@ -14,13 +14,11 @@ import java.util.stream.Collectors;
|
||||
public class RuleSemanticModeller implements SemanticModeller {
|
||||
|
||||
/**
 * Rule-based model building: converts every column of {@code dbSchema} into a
 * {@code ColumnSchema} via {@code convert} and stores the resulting list on the model schema.
 * {@code dbSchemas} and {@code modelBuildReq} are not read by the statements shown here.
 */
@Override
|
||||
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
|
||||
public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
|
||||
ModelBuildReq modelBuildReq) {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
List<ColumnSchema> columnSchemas =
|
||||
dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList());
|
||||
modelSchema.setColumnSchemas(columnSchemas);
|
||||
return modelSchema;
|
||||
}
|
||||
|
||||
private ColumnSchema convert(DBColumn dbColumn) {
|
||||
@@ -29,6 +27,7 @@ public class RuleSemanticModeller implements SemanticModeller {
|
||||
columnSchema.setColumnName(dbColumn.getColumnName());
|
||||
columnSchema.setComment(dbColumn.getComment());
|
||||
columnSchema.setDataType(dbColumn.getDataType());
|
||||
columnSchema.setFiledType(dbColumn.getFieldType());
|
||||
return columnSchema;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import java.util.List;
|
||||
|
||||
/**
 * Strategy for deriving a {@code ModelSchema} from database schema information.
 */
public interface SemanticModeller {
|
||||
|
||||
// Builds the model schema for dbSchema. In the four-argument form the result is written into
// the supplied modelSchema instead of being returned; otherDbSchema carries the other schemas
// available as context and modelBuildReq carries the build options.
ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelBuildReq modelBuildReq);
|
||||
void build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelSchema modelSchema,
|
||||
ModelBuildReq modelBuildReq);
|
||||
|
||||
}
|
||||
|
||||
@@ -222,8 +222,11 @@ public class ModelServiceImpl implements ModelService {
|
||||
|
||||
/**
 * Builds the model schema for {@code curSchema} and stores it in {@code modelSchemaMap} under
 * the key {@code curSchema.getTable()}.
 *
 * <p>In the list-based form, every modeller returned by
 * {@code CoreComponentFactory.getSemanticModellers()} is applied in list order to the same
 * {@code ModelSchema} instance.
 */
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
||||
Map<String, ModelSchema> modelSchemaMap) {
|
||||
SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller();
|
||||
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq);
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
List<SemanticModeller> semanticModellers = CoreComponentFactory.getSemanticModellers();
|
||||
// Each modeller receives the schema produced so far and may update it in place.
for (SemanticModeller semanticModeller : semanticModellers) {
|
||||
semanticModeller.build(curSchema, dbSchemas, modelSchema, modelBuildReq);
|
||||
}
|
||||
modelSchemaMap.put(curSchema.getTable(), modelSchema);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,15 +4,26 @@ import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Object factory for the SemanticModeller implementations used for model building
|
||||
*/
|
||||
@Slf4j
|
||||
public class CoreComponentFactory extends ComponentFactory {
|
||||
|
||||
private static SemanticModeller semanticModeller;
|
||||
// Shared list of modellers; filled on the first call to getSemanticModellers().
private static List<SemanticModeller> semanticModellers = new ArrayList<>();
|
||||
|
||||
public static SemanticModeller getSemanticModeller() {
|
||||
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
|
||||
/**
 * Returns the modellers, initialising the list first whenever it is still empty (so an empty
 * initialisation result is retried on the next call).
 *
 * <p>NOTE(review): the empty-check and initialisation are not synchronized, and the returned
 * list is the shared mutable instance rather than a copy.
 */
public static List<SemanticModeller> getSemanticModellers() {
|
||||
if (semanticModellers.isEmpty()) {
|
||||
initSemanticModellers();
|
||||
}
|
||||
return semanticModellers;
|
||||
}
|
||||
|
||||
// Fills semanticModellers through init(Class, List) - presumably inherited from
// ComponentFactory and loading the implementations registered for SemanticModeller in the
// SPI configuration; confirm against ComponentFactory.
private static void initSemanticModellers() {
|
||||
init(SemanticModeller.class, semanticModellers);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\
|
||||
### headless-server SPIs
|
||||
|
||||
com.tencent.supersonic.headless.server.modeller.SemanticModeller=\
|
||||
com.tencent.supersonic.headless.server.modeller.RuleSemanticModeller, \
|
||||
com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller
|
||||
|
||||
### chat-server SPIs
|
||||
|
||||
@@ -20,7 +20,7 @@ import java.util.Map;
|
||||
|
||||
@Disabled
|
||||
@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"})
|
||||
public class LLMSemanticModellerTest extends BaseTest {
|
||||
public class SemanticModellerTest extends BaseTest {
|
||||
|
||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||
|
||||
@@ -49,7 +49,7 @@ public class LLMSemanticModellerTest extends BaseTest {
|
||||
Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.foreign_key,
|
||||
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
Assertions.assertEquals(FieldType.partition_time,
|
||||
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
||||
@@ -73,7 +73,7 @@ public class LLMSemanticModellerTest extends BaseTest {
|
||||
|
||||
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
|
||||
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
Assertions.assertEquals(FieldType.partition_time,
|
||||
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Reference in New Issue
Block a user