(improvement)(Headless) Refactor SemanticModeller to run the rule-based modeller first and then the LLM modeller, and automatically infer field types in the rule-based modeller. (#1900)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-11-11 00:10:58 +08:00
committed by GitHub
parent ea6a9ebc5f
commit 87729956e8
12 changed files with 101 additions and 23 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@@ -14,4 +15,6 @@ public class DBColumn {
private String dataType; private String dataType;
private String comment; private String comment;
private FieldType fieldType;
} }

View File

@@ -1,5 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums; package com.tencent.supersonic.headless.api.pojo.enums;
public enum FieldType { public enum FieldType {
primary_key, foreign_key, data_time, dimension, measure; primary_key, foreign_key, partition_time, time, dimension, measure;
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -71,7 +72,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
String columnName = columns.getString("COLUMN_NAME"); String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME"); String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS"); String remarks = columns.getString("REMARKS");
dbColumns.add(new DBColumn(columnName, dataType, remarks)); FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
} }
return dbColumns; return dbColumns;
} }
@@ -82,4 +84,25 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
return connection.getMetaData(); return connection.getMetaData();
} }
/**
 * Infers a semantic {@link FieldType} from a JDBC column type name.
 *
 * <p>Numeric types map to {@code measure}, date/time types map to {@code time}, and
 * everything else (including a {@code null} or unrecognized name) maps to
 * {@code dimension}.
 *
 * @param typeName the type name reported for the column (the {@code TYPE_NAME} value
 *        read from the column metadata); may be {@code null}
 * @return the inferred field type, never {@code null}
 */
protected static FieldType classifyColumnType(String typeName) {
    // A missing type name must not abort schema discovery with a NullPointerException.
    if (typeName == null) {
        return FieldType.dimension;
    }
    // Locale.ROOT keeps the comparison independent of the JVM default locale
    // (e.g. under a Turkish locale "int" would upper-case to "İNT" and never match).
    String normalized = typeName.trim().toUpperCase(java.util.Locale.ROOT);
    // Drop a length/precision suffix such as "DECIMAL(10,2)" or "VARCHAR(255)".
    int parenIndex = normalized.indexOf('(');
    if (parenIndex >= 0) {
        normalized = normalized.substring(0, parenIndex).trim();
    }
    // Drop the "UNSIGNED" modifier some drivers append (e.g. "INT UNSIGNED").
    String unsignedSuffix = " UNSIGNED";
    if (normalized.endsWith(unsignedSuffix)) {
        normalized = normalized.substring(0, normalized.length() - unsignedSuffix.length()).trim();
    }
    switch (normalized) {
        case "INT":
        case "INTEGER":
        case "BIGINT":
        case "SMALLINT":
        case "TINYINT":
        case "MEDIUMINT":
        case "FLOAT":
        case "REAL":
        case "DOUBLE":
        case "DOUBLE PRECISION":
        case "DECIMAL":
        case "NUMERIC":
            return FieldType.measure;
        case "DATE":
        case "TIME":
        case "DATETIME":
        case "TIMESTAMP":
        case "TIMESTAMP WITH TIME ZONE":
            return FieldType.time;
        default:
            return FieldType.dimension;
    }
}
} }

View File

@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -54,7 +55,8 @@ public class H2Adaptor extends BaseDbAdaptor {
String columnName = columns.getString("COLUMN_NAME"); String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME"); String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS"); String remarks = columns.getString("REMARKS");
dbColumns.add(new DBColumn(columnName, dataType, remarks)); FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
} }
return dbColumns; return dbColumns;
} }

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.StringValue;
@@ -105,8 +106,41 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
String columnName = columns.getString("COLUMN_NAME"); String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME"); String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS"); String remarks = columns.getString("REMARKS");
dbColumns.add(new DBColumn(columnName, dataType, remarks)); FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
} }
return dbColumns; return dbColumns;
} }
/**
 * Infers a semantic {@link FieldType} from a PostgreSQL column type name.
 *
 * <p>Both the SQL-standard spellings (e.g. {@code INTEGER}, {@code DOUBLE PRECISION})
 * and PostgreSQL's internal type names (e.g. {@code int4}, {@code int8},
 * {@code float8}, {@code timestamptz}) are recognized. Numeric types map to
 * {@code measure}, date/time types map to {@code time}, and everything else
 * (including a {@code null} or unrecognized name) maps to {@code dimension}.
 *
 * @param typeName the type name reported for the column (the {@code TYPE_NAME} value
 *        read from the column metadata); may be {@code null}
 * @return the inferred field type, never {@code null}
 */
protected static FieldType classifyColumnType(String typeName) {
    // A missing type name must not abort schema discovery with a NullPointerException.
    if (typeName == null) {
        return FieldType.dimension;
    }
    // Locale.ROOT keeps the comparison independent of the JVM default locale.
    String normalized = typeName.trim().toUpperCase(java.util.Locale.ROOT);
    switch (normalized) {
        case "INT":
        case "INTEGER":
        case "BIGINT":
        case "SMALLINT":
        // Internal PostgreSQL names for the integer and floating-point types.
        case "INT2":
        case "INT4":
        case "INT8":
        case "FLOAT4":
        case "FLOAT8":
        case "SERIAL":
        case "BIGSERIAL":
        case "SMALLSERIAL":
        case "REAL":
        case "DOUBLE PRECISION":
        case "NUMERIC":
        case "DECIMAL":
            return FieldType.measure;
        case "DATE":
        case "TIME":
        case "TIMETZ":
        case "TIMESTAMP":
        case "TIMESTAMPTZ":
        case "INTERVAL":
            return FieldType.time;
        // Textual types are listed for readability; they share the default outcome.
        case "VARCHAR":
        case "CHAR":
        case "BPCHAR":
        case "TEXT":
        case "CHARACTER VARYING":
        case "CHARACTER":
        case "UUID":
        default:
            return FieldType.dimension;
    }
}
} }

View File

@@ -45,7 +45,7 @@ public class LLMSemanticModeller implements SemanticModeller {
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:" + "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
+ "\n primary_key: This is a unique identifier for a record row in a database." + "\n primary_key: This is a unique identifier for a record row in a database."
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table." + "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
+ "\n data_time: This represents the time when data is generated in the data warehouse." + "\n partition_time: This represents the time when data is generated in the data warehouse."
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions" + "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. " + "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. " + " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
@@ -66,22 +66,23 @@ public class LLMSemanticModeller implements SemanticModeller {
} }
/*
 * LLM-backed implementation of SemanticModeller.build.
 *
 * Each line of this hunk shows the pre-change text followed by the post-change text.
 * In the post-change version the method returns void and receives the ModelSchema to
 * fill in as a parameter. It returns early when modelBuildReq.isBuildByLLM() is false,
 * or when the chat app for APP_KEY is absent or not enabled. Otherwise it builds a
 * prompt from the current schema and the other schemas, asks the ModelSchemaExtractor
 * for a ModelSchema, and logs the inputs and the result.
 */
@Override @Override
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas, public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
ModelBuildReq modelBuildReq) { ModelBuildReq modelBuildReq) {
if (!modelBuildReq.isBuildByLLM()) {
return;
}
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY); Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
if (!chatApp.isPresent() || !chatApp.get().isEnable()) { if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
return null; return;
} }
List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas); List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas);
ModelSchemaExtractor extractor = ModelSchemaExtractor extractor =
AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq)); AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq));
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get()); Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
// NOTE(review): in the post-change version this assignment only rebinds the local
// parameter; Java passes references by value, so the ModelSchema instance held by the
// caller is not modified here. Unless the result is copied into the passed-in object
// (e.g. via its setters), the LLM output appears to be logged and then dropped -- confirm
// against the caller and fix by populating the parameter instead of reassigning it.
ModelSchema modelSchema = modelSchema = extractor.generateModelSchema(prompt.toUserMessage().singleText());
extractor.generateModelSchema(prompt.toUserMessage().singleText());
log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}", log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}",
JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema), JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema),
JsonUtil.toString(modelSchema)); JsonUtil.toString(modelSchema));
return modelSchema;
} }
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) { private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
@@ -113,7 +114,7 @@ public class LLMSemanticModeller implements SemanticModeller {
Environment environment = ContextUtils.getBean(Environment.class); Environment environment = ContextUtils.getBean(Environment.class);
String enableExemplarLoading = String enableExemplarLoading =
environment.getProperty("s2.model.building.exemplars.enabled"); environment.getProperty("s2.model.building.exemplars.enabled");
if (Boolean.TRUE.equals(Boolean.parseBoolean(enableExemplarLoading))) { if (Boolean.FALSE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
log.info("Not enable load model-building exemplars"); log.info("Not enable load model-building exemplars");
return ""; return "";
} }

View File

@@ -14,13 +14,11 @@ import java.util.stream.Collectors;
public class RuleSemanticModeller implements SemanticModeller { public class RuleSemanticModeller implements SemanticModeller {
@Override @Override
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas, public void build(DbSchema dbSchema, List<DbSchema> dbSchemas, ModelSchema modelSchema,
ModelBuildReq modelBuildReq) { ModelBuildReq modelBuildReq) {
ModelSchema modelSchema = new ModelSchema();
List<ColumnSchema> columnSchemas = List<ColumnSchema> columnSchemas =
dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList()); dbSchema.getDbColumns().stream().map(this::convert).collect(Collectors.toList());
modelSchema.setColumnSchemas(columnSchemas); modelSchema.setColumnSchemas(columnSchemas);
return modelSchema;
} }
private ColumnSchema convert(DBColumn dbColumn) { private ColumnSchema convert(DBColumn dbColumn) {
@@ -29,6 +27,7 @@ public class RuleSemanticModeller implements SemanticModeller {
columnSchema.setColumnName(dbColumn.getColumnName()); columnSchema.setColumnName(dbColumn.getColumnName());
columnSchema.setComment(dbColumn.getComment()); columnSchema.setComment(dbColumn.getComment());
columnSchema.setDataType(dbColumn.getDataType()); columnSchema.setDataType(dbColumn.getDataType());
columnSchema.setFiledType(dbColumn.getFieldType());
return columnSchema; return columnSchema;
} }

View File

@@ -9,6 +9,7 @@ import java.util.List;
public interface SemanticModeller { public interface SemanticModeller {
ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelBuildReq modelBuildReq); void build(DbSchema dbSchema, List<DbSchema> otherDbSchema, ModelSchema modelSchema,
ModelBuildReq modelBuildReq);
} }

View File

@@ -222,8 +222,11 @@ public class ModelServiceImpl implements ModelService {
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas, private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
Map<String, ModelSchema> modelSchemaMap) { Map<String, ModelSchema> modelSchemaMap) {
SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller(); ModelSchema modelSchema = new ModelSchema();
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq); List<SemanticModeller> semanticModellers = CoreComponentFactory.getSemanticModellers();
for (SemanticModeller semanticModeller : semanticModellers) {
semanticModeller.build(curSchema, dbSchemas, modelSchema, modelBuildReq);
}
modelSchemaMap.put(curSchema.getTable(), modelSchema); modelSchemaMap.put(curSchema.getTable(), modelSchema);
} }

View File

@@ -4,15 +4,26 @@ import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.modeller.SemanticModeller; import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
/** /**
* QueryConverter QueryOptimizer QueryExecutor object factory * QueryConverter QueryOptimizer QueryExecutor object factory
*/ */
@Slf4j @Slf4j
public class CoreComponentFactory extends ComponentFactory { public class CoreComponentFactory extends ComponentFactory {
private static SemanticModeller semanticModeller; private static List<SemanticModeller> semanticModellers = new ArrayList<>();
public static SemanticModeller getSemanticModeller() { public static List<SemanticModeller> getSemanticModellers() {
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller; if (semanticModellers.isEmpty()) {
initSemanticModellers();
} }
return semanticModellers;
}
private static void initSemanticModellers() {
init(SemanticModeller.class, semanticModellers);
}
} }

View File

@@ -46,6 +46,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\
### headless-server SPIs ### headless-server SPIs
com.tencent.supersonic.headless.server.modeller.SemanticModeller=\ com.tencent.supersonic.headless.server.modeller.SemanticModeller=\
com.tencent.supersonic.headless.server.modeller.RuleSemanticModeller, \
com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller
### chat-server SPIs ### chat-server SPIs

View File

@@ -20,7 +20,7 @@ import java.util.Map;
@Disabled @Disabled
@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"}) @TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"})
public class LLMSemanticModellerTest extends BaseTest { public class SemanticModellerTest extends BaseTest {
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3; private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
@@ -49,7 +49,7 @@ public class LLMSemanticModellerTest extends BaseTest {
Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size()); Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size());
Assertions.assertEquals(FieldType.foreign_key, Assertions.assertEquals(FieldType.foreign_key,
stayTimeModelSchema.getColumnByName("user_name").getFiledType()); stayTimeModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.data_time, Assertions.assertEquals(FieldType.partition_time,
stayTimeModelSchema.getColumnByName("imp_date").getFiledType()); stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension, Assertions.assertEquals(FieldType.dimension,
stayTimeModelSchema.getColumnByName("page").getFiledType()); stayTimeModelSchema.getColumnByName("page").getFiledType());
@@ -73,7 +73,7 @@ public class LLMSemanticModellerTest extends BaseTest {
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next(); ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size()); Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
Assertions.assertEquals(FieldType.data_time, Assertions.assertEquals(FieldType.partition_time,
pvModelSchema.getColumnByName("imp_date").getFiledType()); pvModelSchema.getColumnByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension, Assertions.assertEquals(FieldType.dimension,
pvModelSchema.getColumnByName("user_name").getFiledType()); pvModelSchema.getColumnByName("user_name").getFiledType());