diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DBColumn.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DBColumn.java index c9be51461..b492e708a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DBColumn.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DBColumn.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.api.pojo; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -14,4 +15,6 @@ public class DBColumn { private String dataType; private String comment; + + private FieldType fieldType; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java index f14fa5297..ce24f76c9 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java @@ -1,5 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum FieldType { - primary_key, foreign_key, data_time, dimension, measure; + primary_key, foreign_key, partition_time, time, dimension, measure; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java index 9b2f1029e..b03521996 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.adaptor.db; import com.google.common.collect.Lists; import com.tencent.supersonic.headless.api.pojo.DBColumn; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import lombok.extern.slf4j.Slf4j; @@ -71,7 +72,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor { String columnName = columns.getString("COLUMN_NAME"); String dataType = columns.getString("TYPE_NAME"); String remarks = columns.getString("REMARKS"); - dbColumns.add(new DBColumn(columnName, dataType, remarks)); + FieldType fieldType = classifyColumnType(dataType); + dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType)); } return dbColumns; } @@ -82,4 +84,25 @@ public abstract class BaseDbAdaptor implements DbAdaptor { return connection.getMetaData(); } + protected static FieldType classifyColumnType(String typeName) { + switch (typeName.toUpperCase()) { + case "INT": + case "INTEGER": + case "BIGINT": + case "SMALLINT": + case "TINYINT": + case "FLOAT": + case "DOUBLE": + case "DECIMAL": + case "NUMERIC": + return FieldType.measure; + case "DATE": + case "TIME": + case "TIMESTAMP": + return FieldType.time; + default: + return FieldType.dimension; + } + } + } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java index 9fe3ca56c..7ae1e7cc1 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java @@ -4,6 +4,7 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.DBColumn; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import lombok.extern.slf4j.Slf4j; @@ -54,7 +55,8 @@ public class H2Adaptor extends BaseDbAdaptor { String columnName = columns.getString("COLUMN_NAME"); String dataType = columns.getString("TYPE_NAME"); String remarks = columns.getString("REMARKS"); - dbColumns.add(new DBColumn(columnName, dataType, remarks)); + FieldType fieldType = classifyColumnType(dataType); + dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType)); } return dbColumns; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java index 5e7c1e3e0..42cff0c5b 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.DBColumn; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.StringValue; @@ -105,8 +106,41 @@ public class PostgresqlAdaptor extends BaseDbAdaptor { String columnName = columns.getString("COLUMN_NAME"); String dataType = columns.getString("TYPE_NAME"); String remarks = columns.getString("REMARKS"); - dbColumns.add(new DBColumn(columnName, dataType, remarks)); + FieldType fieldType = classifyColumnType(dataType); + dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType)); } return dbColumns; } + + protected static FieldType classifyColumnType(String typeName) { + switch (typeName.toUpperCase()) { + case "INT": + case "INTEGER": + case "BIGINT": + case "SMALLINT": + case "SERIAL": + case "BIGSERIAL": + case "SMALLSERIAL": + case "REAL": + case "DOUBLE PRECISION": + case "NUMERIC": + case "DECIMAL": + return FieldType.measure; + case "DATE": + case "TIME": + case "TIMESTAMP": + case "TIMESTAMPTZ": + case "INTERVAL": + return FieldType.time; + case "VARCHAR": + case "CHAR": + case "TEXT": + case "CHARACTER VARYING": + case "CHARACTER": + case "UUID": + default: + return FieldType.dimension; + } + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/LLMSemanticModeller.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/LLMSemanticModeller.java index 90440339a..b601fb252 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/LLMSemanticModeller.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/LLMSemanticModeller.java @@ -45,7 +45,7 @@ public class LLMSemanticModeller implements SemanticModeller { + "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:" + "\n primary_key: This is a unique identifier for a record row in a database." + "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table." - + "\n data_time: This represents the time when data is generated in the data warehouse." + + "\n partition_time: This represents the time when data is generated in the data warehouse." + "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions" + "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. " + " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. " @@ -66,22 +66,23 @@ public class LLMSemanticModeller implements SemanticModeller { } @Override - public ModelSchema build(DbSchema dbSchema, List dbSchemas, + public void build(DbSchema dbSchema, List dbSchemas, ModelSchema modelSchema, ModelBuildReq modelBuildReq) { + if (!modelBuildReq.isBuildByLLM()) { + return; + } Optional chatApp = ChatAppManager.getApp(APP_KEY); if (!chatApp.isPresent() || !chatApp.get().isEnable()) { - return null; + return; } List otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas); ModelSchemaExtractor extractor = AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq)); Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get()); - ModelSchema modelSchema = - extractor.generateModelSchema(prompt.toUserMessage().singleText()); + modelSchema = extractor.generateModelSchema(prompt.toUserMessage().singleText()); log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}", JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema), JsonUtil.toString(modelSchema)); - return modelSchema; } private List getOtherDbSchema(DbSchema curSchema, List dbSchemas) { @@ -113,7 +114,7 @@ public class LLMSemanticModeller implements SemanticModeller { Environment environment = ContextUtils.getBean(Environment.class); String enableExemplarLoading = environment.getProperty("s2.model.building.exemplars.enabled"); - if (Boolean.TRUE.equals(Boolean.parseBoolean(enableExemplarLoading))) { + if (Boolean.FALSE.equals(Boolean.parseBoolean(enableExemplarLoading))) { log.info("Not enable load model-building exemplars"); return ""; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java index c008eabd3..b27c25169 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java @@ -14,13 +14,11 @@ import java.util.stream.Collectors; public class RuleSemanticModeller implements SemanticModeller { @Override - public ModelSchema build(DbSchema dbSchema, List dbSchemas, + public void build(DbSchema dbSchema, List dbSchemas, ModelSchema modelSchema, ModelBuildReq modelBuildReq) { - ModelSchema modelSchema = new ModelSchema(); List columnSchemas = dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList()); modelSchema.setColumnSchemas(columnSchemas); - return modelSchema; } private ColumnSchema convert(DBColumn dbColumn) { @@ -29,6 +27,7 @@ public class RuleSemanticModeller implements SemanticModeller { columnSchema.setColumnName(dbColumn.getColumnName()); columnSchema.setComment(dbColumn.getComment()); columnSchema.setDataType(dbColumn.getDataType()); + columnSchema.setFiledType(dbColumn.getFieldType()); return columnSchema; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/SemanticModeller.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/SemanticModeller.java index e1a287581..c8a15cd96 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/SemanticModeller.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/SemanticModeller.java @@ -9,6 +9,7 @@ import java.util.List; public interface SemanticModeller { - ModelSchema build(DbSchema dbSchema, List otherDbSchema, ModelBuildReq modelBuildReq); + void build(DbSchema dbSchema, List otherDbSchema, ModelSchema modelSchema, + ModelBuildReq modelBuildReq); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index 75594c745..f59d7e64e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -222,8 +222,11 @@ public class ModelServiceImpl implements ModelService { private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List dbSchemas, Map modelSchemaMap) { - SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller(); - ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq); + ModelSchema modelSchema = new ModelSchema(); + List semanticModellers = CoreComponentFactory.getSemanticModellers(); + for (SemanticModeller semanticModeller : semanticModellers) { + semanticModeller.build(curSchema, dbSchemas, modelSchema, modelBuildReq); + } modelSchemaMap.put(curSchema.getTable(), modelSchema); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java index 602062551..d480c9fe5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java @@ -4,15 +4,26 @@ import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import com.tencent.supersonic.headless.server.modeller.SemanticModeller; import lombok.extern.slf4j.Slf4j; +import java.util.ArrayList; +import java.util.List; + /** * QueryConverter QueryOptimizer QueryExecutor object factory */ @Slf4j public class CoreComponentFactory extends ComponentFactory { - private static SemanticModeller semanticModeller; + private static List semanticModellers = new ArrayList<>(); - public static SemanticModeller getSemanticModeller() { - return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller; + public static List getSemanticModellers() { + if (semanticModellers.isEmpty()) { + initSemanticModellers(); + } + return semanticModellers; } + + private static void initSemanticModellers() { + init(SemanticModeller.class, semanticModellers); + } + } diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index a05c801e2..0a989e182 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -46,6 +46,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\ ### headless-server SPIs com.tencent.supersonic.headless.server.modeller.SemanticModeller=\ + com.tencent.supersonic.headless.server.modeller.RuleSemanticModeller, \ com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller ### chat-server SPIs diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/LLMSemanticModellerTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java similarity index 96% rename from launchers/standalone/src/test/java/com/tencent/supersonic/headless/LLMSemanticModellerTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java index fb8d39eeb..a8793fb91 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/LLMSemanticModellerTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java @@ -20,7 +20,7 @@ import java.util.Map; @Disabled @TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"}) -public class LLMSemanticModellerTest extends BaseTest { +public class SemanticModellerTest extends BaseTest { private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3; @@ -49,7 +49,7 @@ public class LLMSemanticModellerTest extends BaseTest { Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size()); Assertions.assertEquals(FieldType.foreign_key, stayTimeModelSchema.getColumnByName("user_name").getFiledType()); - Assertions.assertEquals(FieldType.data_time, + Assertions.assertEquals(FieldType.partition_time, stayTimeModelSchema.getColumnByName("imp_date").getFiledType()); Assertions.assertEquals(FieldType.dimension, stayTimeModelSchema.getColumnByName("page").getFiledType()); @@ -73,7 +73,7 @@ public class LLMSemanticModellerTest extends BaseTest { ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next(); Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size()); - Assertions.assertEquals(FieldType.data_time, + Assertions.assertEquals(FieldType.partition_time, pvModelSchema.getColumnByName("imp_date").getFiledType()); Assertions.assertEquals(FieldType.dimension, pvModelSchema.getColumnByName("user_name").getFiledType());