[improvement][Headless] Supports automatic batch creation of models based on db table names.

This commit is contained in:
lxwcodemonkey
2024-11-30 22:22:56 +08:00
parent 8299084c95
commit 6ecc5a9362
6 changed files with 82 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.headless.api.pojo.DbSchema;
import lombok.Data;
import java.util.List;
@@ -10,12 +11,16 @@ public class ModelBuildReq {
private Long databaseId;
private Long domainId;
private String sql;
private String db;
private List<String> tables;
private List<DbSchema> dbSchemas;
private boolean buildByLLM;
private Integer chatModelId;

View File

@@ -50,6 +50,14 @@ public class ModelController {
return true;
}
@PostMapping("/createModelBatch")
public Boolean createModelBatch(@RequestBody ModelBuildReq modelBuildReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
modelService.createModel(modelBuildReq, user);
return true;
}
@PostMapping("/updateModel")
public Boolean updateModel(@RequestBody ModelReq modelReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {

View File

@@ -23,6 +23,8 @@ public interface ModelService {
ModelResp createModel(ModelReq datasourceReq, User user) throws Exception;
List<ModelResp> createModel(ModelBuildReq modelBuildReq, User user) throws Exception;
ModelResp updateModel(ModelReq datasourceReq, User user) throws Exception;
List<ModelResp> getModelList(MetaFilter metaFilter);

View File

@@ -138,7 +138,8 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
QueryWrapper<DatabaseDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(DatabaseDO::getType, dataType.getFeature());
List<DatabaseDO> list = list(queryWrapper);
return list.stream().map(DatabaseConverter::convertWithPassword).collect(Collectors.toList());
return list.stream().map(DatabaseConverter::convertWithPassword)
.collect(Collectors.toList());
}
@Override

View File

@@ -125,6 +125,19 @@ public class ModelServiceImpl implements ModelService {
return ModelConverter.convert(modelDO);
}
@Override
public List<ModelResp> createModel(ModelBuildReq modelBuildReq, User user) throws Exception {
List<ModelResp> modelResps = Lists.newArrayList();
Map<String, ModelSchema> modelSchemaMap = buildModelSchema(modelBuildReq);
for (Map.Entry<String, ModelSchema> entry : modelSchemaMap.entrySet()) {
ModelReq modelReq =
ModelConverter.convert(entry.getValue(), modelBuildReq, entry.getKey());
ModelResp modelResp = createModel(modelReq, user);
modelResps.add(modelResp);
}
return modelResps;
}
@Override
@Transactional
public ModelResp updateModel(ModelReq modelReq, User user) throws Exception {
@@ -231,6 +244,9 @@ public class ModelServiceImpl implements ModelService {
}
private List<DbSchema> getDbSchemes(ModelBuildReq modelBuildReq) throws SQLException {
if (!CollectionUtils.isEmpty(modelBuildReq.getDbSchemas())) {
return modelBuildReq.getDbSchemas();
}
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelBuildReq);
return convert(dbColumnMap, modelBuildReq);
}

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.ColumnSchema;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.Identify;
@@ -14,11 +15,16 @@ import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.ModelDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MeasureResp;
@@ -156,6 +162,49 @@ public class ModelConverter {
return dimensionReq;
}
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
String tableName) {
ModelReq modelReq = new ModelReq();
modelReq.setName(modelSchema.getName());
modelReq.setBizName(modelSchema.getBizName());
modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
modelReq.setDomainId(modelBuildReq.getDomainId());
ModelDetail modelDetail = new ModelDetail();
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
modelDetail.setQueryType(ModelDefineType.SQL_QUERY.getName());
modelDetail.setSqlQuery(modelBuildReq.getSql());
} else {
modelDetail.setQueryType(ModelDefineType.TABLE_QUERY.getName());
modelDetail.setTableQuery(String.format("%s.%s", modelBuildReq.getDb(), tableName));
}
for (ColumnSchema columnSchema : modelSchema.getColumnSchemas()) {
FieldType fieldType = columnSchema.getFiledType();
if (getIdentifyType(fieldType) != null) {
Identify identify = new Identify(columnSchema.getName(),
getIdentifyType(fieldType).name(), columnSchema.getColumnName(), 1);
modelDetail.getIdentifiers().add(identify);
} else if (FieldType.measure.equals(fieldType)) {
Measure measure = new Measure(columnSchema.getName(), columnSchema.getColumnName(),
columnSchema.getAgg().getOperator(), 1);
modelDetail.getMeasures().add(measure);
} else {
Dim dim = new Dim(columnSchema.getName(), columnSchema.getColumnName(),
DimensionType.valueOf(columnSchema.getFiledType().name()), 1);
modelDetail.getDimensions().add(dim);
}
}
modelReq.setModelDetail(modelDetail);
return modelReq;
}
private static IdentifyType getIdentifyType(FieldType fieldType) {
if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) {
return IdentifyType.primary;
} else {
return IdentifyType.foreign;
}
}
public static List<ModelResp> convertList(List<ModelDO> modelDOS) {
List<ModelResp> modelDescs = Lists.newArrayList();
if (!CollectionUtils.isEmpty(modelDOS)) {