(improvement)(Headless) Support building data-model by LLM #1319 (#1784)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-10-11 14:34:17 +08:00
committed by GitHub
parent 50b0036d0f
commit fbf0ea0627
15 changed files with 224 additions and 15 deletions

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class DbSchema {
private String db;
private String table;
private String sql;
private List<DBColumn> dbColumns;
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.headless.api.pojo;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import lombok.Data;
@Data
public class FieldSchema {
private String columnName;
private String dataType;
private String comment;
private FieldType filedType;
private String name;
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.headless.api.pojo;
import java.util.List;
import lombok.Data;
@Data
public class ModelSchema {
private String name;
private String bizName;
private String description;
private List<FieldSchema> filedSchemas;
}

View File

@@ -0,0 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum FieldType {
primary_key, foreign_key, data_time, dimension, measure;
}

View File

@@ -1,7 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
@@ -41,13 +39,6 @@ public class ModelReq extends SchemaItem {
private Map<String, Object> ext;
public List<Dim> getTimeDimension() {
if (modelDetail == null) {
return Lists.newArrayList();
}
return modelDetail.filterTimeDims();
}
public String getViewer() {
if (viewers == null) {
return null;

View File

@@ -3,7 +3,15 @@ package com.tencent.supersonic.headless.api.pojo.request;
import lombok.Data;
@Data
public class ColumnReq {
public class ModelSchemaReq {
private Long databaseId;
private String sql;
private String db;
private String table;
private boolean buildByLLM;
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.headless.server.builder;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider;
public abstract class IntelligentBuilder {
protected ChatLanguageModel getChatModel() {
return ModelProvider.getChatModel();
}
}

View File

@@ -0,0 +1,58 @@
package com.tencent.supersonic.headless.server.builder;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DbSchema;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.service.AiServices;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component
public class ModelIntelligentBuilder extends IntelligentBuilder {
public static final String INSTRUCTION = ""
+ "Role: As an experienced data analyst with extensive modeling experience, "
+ " you are expected to have a deep understanding of data analysis and data modeling concepts."
+ "\nJob: You will be given a database table structure, which includes the database table name, field name,"
+ " field type, and field comments. Your task is to utilize this information for data modeling."
+ "\nTask:"
+ "\n1. Generate a name and description for the model. Please note, 'bizName' refers to the English name, while 'name' is the Chinese name."
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
+ "\n primary_key: This is a unique identifier for a record row in a database."
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
+ "\n data_time: This represents the time when data is generated in the data warehouse."
+ "\n dimension: Usually a string type, used for grouping and filtering data."
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective."
+ "\nDBSchema: {{DBSchema}}" + "\nExemplar: {{exemplar}}";
interface ModelSchemaExtractor {
ModelSchema generateModelSchema(String text);
}
public ModelSchema build(DbSchema dbSchema) {
ChatLanguageModel chatModel = getChatModel();
ModelSchemaExtractor extractor = AiServices.create(ModelSchemaExtractor.class, chatModel);
Prompt prompt = generatePrompt(dbSchema);
return extractor.generateModelSchema(prompt.toUserMessage().singleText());
}
private Prompt generatePrompt(DbSchema dbSchema) {
Map<String, Object> variable = new HashMap<>();
variable.put("exemplar", loadExemplars());
variable.put("DBSchema", JsonUtil.toString(dbSchema));
return PromptTemplate.from(INSTRUCTION).apply(variable);
}
private String loadExemplars() {
// to add
return "";
}
}

View File

@@ -6,8 +6,8 @@ import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.ColumnReq;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -96,8 +96,9 @@ public class DatabaseController {
}
@PostMapping("/listColumnsBySql")
public List<DBColumn> listColumnsBySql(@RequestBody ColumnReq columnReq) throws SQLException {
return databaseService.getColumns(columnReq.getDatabaseId(), columnReq.getSql());
public List<DBColumn> listColumnsBySql(@RequestBody ModelSchemaReq modelSchemaReq)
throws SQLException {
return databaseService.getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
}
@GetMapping("/getDatabaseParameters")

View File

@@ -7,9 +7,11 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
@@ -24,6 +26,7 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
@@ -105,4 +108,10 @@ public class ModelController {
public UnAvailableItemResp getUnAvailableItem(@RequestBody FieldRemovedReq fieldRemovedReq) {
return modelService.getUnAvailableItem(fieldRemovedReq);
}
@PostMapping("/buildModelSchema")
public ModelSchema buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
throws SQLException {
return modelService.buildModelSchema(modelSchemaReq);
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -36,6 +37,8 @@ public interface DatabaseService {
List<String> getTables(Long id, String db) throws SQLException;
List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;
List<DBColumn> getColumns(Long id, String sql) throws SQLException;

View File

@@ -5,14 +5,17 @@ import com.tencent.supersonic.common.pojo.ItemDateResp;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.ItemDateFilter;
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
import com.tencent.supersonic.headless.server.pojo.ModelFilter;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
@@ -32,6 +35,8 @@ public interface ModelService {
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);
List<ModelResp> getModelAuthList(User user, Long domainId, AuthType authTypeEnum);

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
@@ -205,6 +206,16 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
return dbAdaptor.getTables(DatabaseConverter.getConnectInfo(databaseResp), db);
}
@Override
public List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
} else {
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(),
modelSchemaReq.getTable());
}
}
@Override
public List<DBColumn> getColumns(Long id, String db, String table) throws SQLException {
DatabaseResp databaseResp = getDatabase(id);

View File

@@ -9,17 +9,22 @@ import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.DbSchema;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.FieldSchema;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.ItemDateFilter;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.DateInfoReq;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
@@ -27,6 +32,7 @@ import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
import com.tencent.supersonic.headless.server.builder.ModelIntelligentBuilder;
import com.tencent.supersonic.headless.server.persistence.dataobject.DateInfoDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO;
import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository;
@@ -48,6 +54,7 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Date;
@@ -78,10 +85,13 @@ public class ModelServiceImpl implements ModelService {
private DateInfoRepository dateInfoRepository;
private ModelIntelligentBuilder modelIntelligentBuilder;
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
DomainService domainService, UserService userService, DataSetService dataSetService,
DateInfoRepository dateInfoRepository) {
DateInfoRepository dateInfoRepository,
ModelIntelligentBuilder modelIntelligentBuilder) {
this.modelRepository = modelRepository;
this.databaseService = databaseService;
this.dimensionService = dimensionService;
@@ -90,6 +100,7 @@ public class ModelServiceImpl implements ModelService {
this.userService = userService;
this.dataSetService = dataSetService;
this.dateInfoRepository = dateInfoRepository;
this.modelIntelligentBuilder = modelIntelligentBuilder;
}
@Override
@@ -184,6 +195,42 @@ public class ModelServiceImpl implements ModelService {
.build();
}
@Override
public ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
List<DBColumn> dbColumns = databaseService.getDbColumns(modelSchemaReq);
if (modelSchemaReq.isBuildByLLM()) {
DbSchema dbSchema = convert(modelSchemaReq, dbColumns);
return modelIntelligentBuilder.build(dbSchema);
}
return build(dbColumns);
}
private DbSchema convert(ModelSchemaReq modelSchemaReq, List<DBColumn> dbColumns) {
DbSchema dbSchema = new DbSchema();
dbSchema.setDb(modelSchemaReq.getDb());
dbSchema.setTable(modelSchemaReq.getTable());
dbSchema.setSql(modelSchemaReq.getSql());
dbSchema.setDbColumns(dbColumns);
return dbSchema;
}
private FieldSchema convert(DBColumn dbColumn) {
FieldSchema fieldSchema = new FieldSchema();
fieldSchema.setName(dbColumn.getComment());
fieldSchema.setColumnName(dbColumn.getColumnName());
fieldSchema.setComment(dbColumn.getComment());
fieldSchema.setDataType(dbColumn.getDataType());
return fieldSchema;
}
private ModelSchema build(List<DBColumn> dbColumns) {
ModelSchema modelSchema = new ModelSchema();
List<FieldSchema> fieldSchemas =
dbColumns.stream().map(this::convert).collect(Collectors.toList());
modelSchema.setFiledSchemas(fieldSchemas);
return modelSchema;
}
private void batchCreateDimension(ModelDO modelDO, User user) throws Exception {
List<DimensionReq> dimensionReqs = ModelConverter.convertDimensionList(modelDO);
dimensionService.createDimensionBatch(dimensionReqs, user);

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.server.builder.ModelIntelligentBuilder;
import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO;
import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository;
import com.tencent.supersonic.headless.server.persistence.repository.ModelRepository;
@@ -75,8 +76,11 @@ class ModelServiceImplTest {
UserService userService = Mockito.mock(UserService.class);
DateInfoRepository dateInfoRepository = Mockito.mock(DateInfoRepository.class);
DataSetService viewService = Mockito.mock(DataSetService.class);
ModelIntelligentBuilder modelIntelligentBuilder =
Mockito.mock(ModelIntelligentBuilder.class);
return new ModelServiceImpl(modelRepository, databaseService, dimensionService,
metricService, domainService, userService, viewService, dateInfoRepository);
metricService, domainService, userService, viewService, dateInfoRepository,
modelIntelligentBuilder);
}
private ModelReq mockModelReq() {