From 6ecc5a93624a87b1c0e49d3ecd3a969c997ac008 Mon Sep 17 00:00:00 2001 From: lxwcodemonkey Date: Sat, 30 Nov 2024 22:22:56 +0800 Subject: [PATCH] [improvement][Headless] Supports automatic batch creation of models based on db table names. --- .../api/pojo/request/ModelBuildReq.java | 5 ++ .../headless/server/rest/ModelController.java | 8 +++ .../headless/server/service/ModelService.java | 2 + .../service/impl/DatabaseServiceImpl.java | 3 +- .../server/service/impl/ModelServiceImpl.java | 16 ++++++ .../headless/server/utils/ModelConverter.java | 49 +++++++++++++++++++ 6 files changed, 82 insertions(+), 1 deletion(-) diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java index 38299696a..9385bb28b 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.headless.api.pojo.DbSchema; import lombok.Data; import java.util.List; @@ -10,12 +11,16 @@ public class ModelBuildReq { private Long databaseId; + private Long domainId; + private String sql; private String db; private List tables; + private List dbSchemas; + private boolean buildByLLM; private Integer chatModelId; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java index 0605101c0..dd552636b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java @@ -50,6 +50,14 @@ public class ModelController { return true; } + @PostMapping("/createModelBatch") + public Boolean createModelBatch(@RequestBody ModelBuildReq modelBuildReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { + User user = UserHolder.findUser(request, response); + modelService.createModel(modelBuildReq, user); + return true; + } + @PostMapping("/updateModel") public Boolean updateModel(@RequestBody ModelReq modelReq, HttpServletRequest request, HttpServletResponse response) throws Exception { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java index 203884b4d..471bf84d6 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java @@ -23,6 +23,8 @@ public interface ModelService { ModelResp createModel(ModelReq datasourceReq, User user) throws Exception; + List createModel(ModelBuildReq modelBuildReq, User user) throws Exception; + ModelResp updateModel(ModelReq datasourceReq, User user) throws Exception; List getModelList(MetaFilter metaFilter); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 61a1d0906..976d90fbd 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -138,7 +138,8 @@ public class DatabaseServiceImpl extends ServiceImpl queryWrapper = new QueryWrapper<>(); queryWrapper.lambda().eq(DatabaseDO::getType, dataType.getFeature()); List list = list(queryWrapper); - return list.stream().map(DatabaseConverter::convertWithPassword).collect(Collectors.toList()); + return list.stream().map(DatabaseConverter::convertWithPassword) + .collect(Collectors.toList()); } @Override diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index e1c64a445..a7da69a61 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -125,6 +125,19 @@ public class ModelServiceImpl implements ModelService { return ModelConverter.convert(modelDO); } + @Override + public List createModel(ModelBuildReq modelBuildReq, User user) throws Exception { + List modelResps = Lists.newArrayList(); + Map modelSchemaMap = buildModelSchema(modelBuildReq); + for (Map.Entry entry : modelSchemaMap.entrySet()) { + ModelReq modelReq = + ModelConverter.convert(entry.getValue(), modelBuildReq, entry.getKey()); + ModelResp modelResp = createModel(modelReq, user); + modelResps.add(modelResp); + } + return modelResps; + } + @Override @Transactional public ModelResp updateModel(ModelReq modelReq, User user) throws Exception { @@ -231,6 +244,9 @@ public class ModelServiceImpl implements ModelService { } private List getDbSchemes(ModelBuildReq modelBuildReq) throws SQLException { + if (!CollectionUtils.isEmpty(modelBuildReq.getDbSchemas())) { + return modelBuildReq.getDbSchemas(); + } Map> dbColumnMap = databaseService.getDbColumns(modelBuildReq); return convert(dbColumnMap, modelBuildReq); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java index dea1cf85e..d004c9a3b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.api.pojo.ColumnSchema; import com.tencent.supersonic.headless.api.pojo.Dim; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.Identify; @@ -14,11 +15,16 @@ import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.pojo.MeasureParam; import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams; import com.tencent.supersonic.headless.api.pojo.ModelDetail; +import com.tencent.supersonic.headless.api.pojo.ModelSchema; import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; +import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; +import com.tencent.supersonic.headless.api.pojo.enums.ModelDefineType; import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; import com.tencent.supersonic.headless.api.pojo.request.MetricReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq; import com.tencent.supersonic.headless.api.pojo.response.DomainResp; import com.tencent.supersonic.headless.api.pojo.response.MeasureResp; @@ -156,6 +162,49 @@ public class ModelConverter { return dimensionReq; } + public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq, + String tableName) { + ModelReq modelReq = new ModelReq(); + modelReq.setName(modelSchema.getName()); + modelReq.setBizName(modelSchema.getBizName()); + modelReq.setDatabaseId(modelBuildReq.getDatabaseId()); + modelReq.setDomainId(modelBuildReq.getDomainId()); + ModelDetail modelDetail = new ModelDetail(); + if (StringUtils.isNotBlank(modelBuildReq.getSql())) { + modelDetail.setQueryType(ModelDefineType.SQL_QUERY.getName()); + modelDetail.setSqlQuery(modelBuildReq.getSql()); + } else { + modelDetail.setQueryType(ModelDefineType.TABLE_QUERY.getName()); + modelDetail.setTableQuery(String.format("%s.%s", modelBuildReq.getDb(), tableName)); + } + for (ColumnSchema columnSchema : modelSchema.getColumnSchemas()) { + FieldType fieldType = columnSchema.getFiledType(); + if (getIdentifyType(fieldType) != null) { + Identify identify = new Identify(columnSchema.getName(), + getIdentifyType(fieldType).name(), columnSchema.getColumnName(), 1); + modelDetail.getIdentifiers().add(identify); + } else if (FieldType.measure.equals(fieldType)) { + Measure measure = new Measure(columnSchema.getName(), columnSchema.getColumnName(), + columnSchema.getAgg().getOperator(), 1); + modelDetail.getMeasures().add(measure); + } else { + Dim dim = new Dim(columnSchema.getName(), columnSchema.getColumnName(), + DimensionType.valueOf(columnSchema.getFiledType().name()), 1); + modelDetail.getDimensions().add(dim); + } + } + modelReq.setModelDetail(modelDetail); + return modelReq; + } + + private static IdentifyType getIdentifyType(FieldType fieldType) { + if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) { + return IdentifyType.primary; + } else { + return IdentifyType.foreign; + } + } + public static List convertList(List modelDOS) { List modelDescs = Lists.newArrayList(); if (!CollectionUtils.isEmpty(modelDOS)) {