mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
(improvement)(Headless) Refactor the SemanticModeller to rule first and then llm, and automatically infer field types in the rule method. (#1900)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.headless.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
@@ -14,4 +15,6 @@ public class DBColumn {
|
|||||||
private String dataType;
|
private String dataType;
|
||||||
|
|
||||||
private String comment;
|
private String comment;
|
||||||
|
|
||||||
|
private FieldType fieldType;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||||
|
|
||||||
/**
 * Semantic role of a table column when it is mapped into a model schema.
 */
public enum FieldType {
    /** Unique identifier for a record row. */
    primary_key,
    /** Key whose value is derived from the primary key of another table. */
    foreign_key,
    /** Time at which the data was generated in the data warehouse. */
    partition_time,
    /** Temporal column (the rule-based classification maps DATE/TIME/TIMESTAMP types here). */
    time,
    /** Column used for grouping and filtering data; usually a string type. */
    dimension,
    /** Column that quantifies data and is meant to be aggregated; usually a numeric type. */
    measure;
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.adaptor.db;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -71,7 +72,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
|||||||
String columnName = columns.getString("COLUMN_NAME");
|
String columnName = columns.getString("COLUMN_NAME");
|
||||||
String dataType = columns.getString("TYPE_NAME");
|
String dataType = columns.getString("TYPE_NAME");
|
||||||
String remarks = columns.getString("REMARKS");
|
String remarks = columns.getString("REMARKS");
|
||||||
dbColumns.add(new DBColumn(columnName, dataType, remarks));
|
FieldType fieldType = classifyColumnType(dataType);
|
||||||
|
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
|
||||||
}
|
}
|
||||||
return dbColumns;
|
return dbColumns;
|
||||||
}
|
}
|
||||||
@@ -82,4 +84,25 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
|||||||
return connection.getMetaData();
|
return connection.getMetaData();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected static FieldType classifyColumnType(String typeName) {
|
||||||
|
switch (typeName.toUpperCase()) {
|
||||||
|
case "INT":
|
||||||
|
case "INTEGER":
|
||||||
|
case "BIGINT":
|
||||||
|
case "SMALLINT":
|
||||||
|
case "TINYINT":
|
||||||
|
case "FLOAT":
|
||||||
|
case "DOUBLE":
|
||||||
|
case "DECIMAL":
|
||||||
|
case "NUMERIC":
|
||||||
|
return FieldType.measure;
|
||||||
|
case "DATE":
|
||||||
|
case "TIME":
|
||||||
|
case "TIMESTAMP":
|
||||||
|
return FieldType.time;
|
||||||
|
default:
|
||||||
|
return FieldType.dimension;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
|
|||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -54,7 +55,8 @@ public class H2Adaptor extends BaseDbAdaptor {
|
|||||||
String columnName = columns.getString("COLUMN_NAME");
|
String columnName = columns.getString("COLUMN_NAME");
|
||||||
String dataType = columns.getString("TYPE_NAME");
|
String dataType = columns.getString("TYPE_NAME");
|
||||||
String remarks = columns.getString("REMARKS");
|
String remarks = columns.getString("REMARKS");
|
||||||
dbColumns.add(new DBColumn(columnName, dataType, remarks));
|
FieldType fieldType = classifyColumnType(dataType);
|
||||||
|
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
|
||||||
}
|
}
|
||||||
return dbColumns;
|
return dbColumns;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
|||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.StringValue;
|
import net.sf.jsqlparser.expression.StringValue;
|
||||||
@@ -105,8 +106,41 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
|
|||||||
String columnName = columns.getString("COLUMN_NAME");
|
String columnName = columns.getString("COLUMN_NAME");
|
||||||
String dataType = columns.getString("TYPE_NAME");
|
String dataType = columns.getString("TYPE_NAME");
|
||||||
String remarks = columns.getString("REMARKS");
|
String remarks = columns.getString("REMARKS");
|
||||||
dbColumns.add(new DBColumn(columnName, dataType, remarks));
|
FieldType fieldType = classifyColumnType(dataType);
|
||||||
|
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
|
||||||
}
|
}
|
||||||
return dbColumns;
|
return dbColumns;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected static FieldType classifyColumnType(String typeName) {
|
||||||
|
switch (typeName.toUpperCase()) {
|
||||||
|
case "INT":
|
||||||
|
case "INTEGER":
|
||||||
|
case "BIGINT":
|
||||||
|
case "SMALLINT":
|
||||||
|
case "SERIAL":
|
||||||
|
case "BIGSERIAL":
|
||||||
|
case "SMALLSERIAL":
|
||||||
|
case "REAL":
|
||||||
|
case "DOUBLE PRECISION":
|
||||||
|
case "NUMERIC":
|
||||||
|
case "DECIMAL":
|
||||||
|
return FieldType.measure;
|
||||||
|
case "DATE":
|
||||||
|
case "TIME":
|
||||||
|
case "TIMESTAMP":
|
||||||
|
case "TIMESTAMPTZ":
|
||||||
|
case "INTERVAL":
|
||||||
|
return FieldType.time;
|
||||||
|
case "VARCHAR":
|
||||||
|
case "CHAR":
|
||||||
|
case "TEXT":
|
||||||
|
case "CHARACTER VARYING":
|
||||||
|
case "CHARACTER":
|
||||||
|
case "UUID":
|
||||||
|
default:
|
||||||
|
return FieldType.dimension;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ public class LLMSemanticModeller implements SemanticModeller {
|
|||||||
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
|
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
|
||||||
+ "\n primary_key: This is a unique identifier for a record row in a database."
|
+ "\n primary_key: This is a unique identifier for a record row in a database."
|
||||||
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
|
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
|
||||||
+ "\n data_time: This represents the time when data is generated in the data warehouse."
|
+ "\n partition_time: This represents the time when data is generated in the data warehouse."
|
||||||
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
|
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
|
||||||
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
|
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
|
||||||
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
|
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
|
||||||
@@ -66,22 +66,23 @@ public class LLMSemanticModeller implements SemanticModeller {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
|
public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
|
||||||
ModelBuildReq modelBuildReq) {
|
ModelBuildReq modelBuildReq) {
|
||||||
|
if (!modelBuildReq.isBuildByLLM()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
|
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
|
||||||
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
|
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
|
||||||
return null;
|
return;
|
||||||
}
|
}
|
||||||
List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas);
|
List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas);
|
||||||
ModelSchemaExtractor extractor =
|
ModelSchemaExtractor extractor =
|
||||||
AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq));
|
AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq));
|
||||||
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
|
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
|
||||||
ModelSchema modelSchema =
|
modelSchema = extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
||||||
extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
|
||||||
log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}",
|
log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}",
|
||||||
JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema),
|
JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema),
|
||||||
JsonUtil.toString(modelSchema));
|
JsonUtil.toString(modelSchema));
|
||||||
return modelSchema;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
|
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
|
||||||
@@ -113,7 +114,7 @@ public class LLMSemanticModeller implements SemanticModeller {
|
|||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
String enableExemplarLoading =
|
String enableExemplarLoading =
|
||||||
environment.getProperty("s2.model.building.exemplars.enabled");
|
environment.getProperty("s2.model.building.exemplars.enabled");
|
||||||
if (Boolean.TRUE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
|
if (Boolean.FALSE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
|
||||||
log.info("Not enable load model-building exemplars");
|
log.info("Not enable load model-building exemplars");
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,13 +14,11 @@ import java.util.stream.Collectors;
|
|||||||
public class RuleSemanticModeller implements SemanticModeller {
|
public class RuleSemanticModeller implements SemanticModeller {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
|
public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
|
||||||
ModelBuildReq modelBuildReq) {
|
ModelBuildReq modelBuildReq) {
|
||||||
ModelSchema modelSchema = new ModelSchema();
|
|
||||||
List<ColumnSchema> columnSchemas =
|
List<ColumnSchema> columnSchemas =
|
||||||
dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList());
|
dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList());
|
||||||
modelSchema.setColumnSchemas(columnSchemas);
|
modelSchema.setColumnSchemas(columnSchemas);
|
||||||
return modelSchema;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private ColumnSchema convert(DBColumn dbColumn) {
|
private ColumnSchema convert(DBColumn dbColumn) {
|
||||||
@@ -29,6 +27,7 @@ public class RuleSemanticModeller implements SemanticModeller {
|
|||||||
columnSchema.setColumnName(dbColumn.getColumnName());
|
columnSchema.setColumnName(dbColumn.getColumnName());
|
||||||
columnSchema.setComment(dbColumn.getComment());
|
columnSchema.setComment(dbColumn.getComment());
|
||||||
columnSchema.setDataType(dbColumn.getDataType());
|
columnSchema.setDataType(dbColumn.getDataType());
|
||||||
|
columnSchema.setFiledType(dbColumn.getFieldType());
|
||||||
return columnSchema;
|
return columnSchema;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import java.util.List;
|
|||||||
|
|
||||||
public interface SemanticModeller {
|
public interface SemanticModeller {
|
||||||
|
|
||||||
ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelBuildReq modelBuildReq);
|
void build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelSchema modelSchema,
|
||||||
|
ModelBuildReq modelBuildReq);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -222,8 +222,11 @@ public class ModelServiceImpl implements ModelService {
|
|||||||
|
|
||||||
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
||||||
Map<String, ModelSchema> modelSchemaMap) {
|
Map<String, ModelSchema> modelSchemaMap) {
|
||||||
SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller();
|
ModelSchema modelSchema = new ModelSchema();
|
||||||
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq);
|
List<SemanticModeller> semanticModellers = CoreComponentFactory.getSemanticModellers();
|
||||||
|
for (SemanticModeller semanticModeller : semanticModellers) {
|
||||||
|
semanticModeller.build(curSchema, dbSchemas, modelSchema, modelBuildReq);
|
||||||
|
}
|
||||||
modelSchemaMap.put(curSchema.getTable(), modelSchema);
|
modelSchemaMap.put(curSchema.getTable(), modelSchema);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,15 +4,26 @@ import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
|||||||
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
|
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* QueryConverter QueryOptimizer QueryExecutor object factory
|
* QueryConverter QueryOptimizer QueryExecutor object factory
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class CoreComponentFactory extends ComponentFactory {
|
public class CoreComponentFactory extends ComponentFactory {
|
||||||
|
|
||||||
private static SemanticModeller semanticModeller;
|
private static List<SemanticModeller> semanticModellers = new ArrayList<>();
|
||||||
|
|
||||||
public static SemanticModeller getSemanticModeller() {
|
public static List<SemanticModeller> getSemanticModellers() {
|
||||||
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
|
if (semanticModellers.isEmpty()) {
|
||||||
|
initSemanticModellers();
|
||||||
}
|
}
|
||||||
|
return semanticModellers;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void initSemanticModellers() {
|
||||||
|
init(SemanticModeller.class, semanticModellers);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\
|
|||||||
### headless-server SPIs
|
### headless-server SPIs
|
||||||
|
|
||||||
com.tencent.supersonic.headless.server.modeller.SemanticModeller=\
|
com.tencent.supersonic.headless.server.modeller.SemanticModeller=\
|
||||||
|
com.tencent.supersonic.headless.server.modeller.RuleSemanticModeller, \
|
||||||
com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller
|
com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller
|
||||||
|
|
||||||
### chat-server SPIs
|
### chat-server SPIs
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import java.util.Map;
|
|||||||
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"})
|
@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"})
|
||||||
public class LLMSemanticModellerTest extends BaseTest {
|
public class SemanticModellerTest extends BaseTest {
|
||||||
|
|
||||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ public class LLMSemanticModellerTest extends BaseTest {
|
|||||||
Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size());
|
Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size());
|
||||||
Assertions.assertEquals(FieldType.foreign_key,
|
Assertions.assertEquals(FieldType.foreign_key,
|
||||||
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.data_time,
|
Assertions.assertEquals(FieldType.partition_time,
|
||||||
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.dimension,
|
Assertions.assertEquals(FieldType.dimension,
|
||||||
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
||||||
@@ -73,7 +73,7 @@ public class LLMSemanticModellerTest extends BaseTest {
|
|||||||
|
|
||||||
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
|
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
|
||||||
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
||||||
Assertions.assertEquals(FieldType.data_time,
|
Assertions.assertEquals(FieldType.partition_time,
|
||||||
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.dimension,
|
Assertions.assertEquals(FieldType.dimension,
|
||||||
pvModelSchema.getColumnByName("user_name").getFiledType());
|
pvModelSchema.getColumnByName("user_name").getFiledType());
|
||||||
Reference in New Issue
Block a user