mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(Headless) Supports batch creation of models by specifying db table names (#1833)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.request;
|
package com.tencent.supersonic.headless.api.pojo.request;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -12,9 +13,11 @@ public class ModelSchemaReq {
|
|||||||
|
|
||||||
private String db;
|
private String db;
|
||||||
|
|
||||||
private String table;
|
private List<String> tables;
|
||||||
|
|
||||||
private boolean buildByLLM;
|
private boolean buildByLLM;
|
||||||
|
|
||||||
|
private Integer chatModelId;
|
||||||
|
|
||||||
private ChatModelConfig chatModelConfig;
|
private ChatModelConfig chatModelConfig;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.server.rest;
|
package com.tencent.supersonic.headless.server.rest;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
@@ -110,7 +111,7 @@ public class ModelController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/buildModelSchema")
|
@PostMapping("/buildModelSchema")
|
||||||
public ModelSchema buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
|
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
|
||||||
throws SQLException {
|
throws SQLException {
|
||||||
return modelService.buildModelSchema(modelSchemaReq);
|
return modelService.buildModelSchema(modelSchemaReq);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ public interface DatabaseService {
|
|||||||
|
|
||||||
List<String> getTables(Long id, String db) throws SQLException;
|
List<String> getTables(Long id, String db) throws SQLException;
|
||||||
|
|
||||||
List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
|
Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||||
|
|
||||||
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;
|
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ public interface ModelService {
|
|||||||
|
|
||||||
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
|
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
|
||||||
|
|
||||||
ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
|
Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||||
|
|
||||||
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);
|
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import com.tencent.supersonic.headless.server.pojo.ModelFilter;
|
|||||||
import com.tencent.supersonic.headless.server.service.DatabaseService;
|
import com.tencent.supersonic.headless.server.service.DatabaseService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
|
import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
|
||||||
|
import java.util.HashMap;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
@@ -207,14 +208,19 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
|
public Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
|
||||||
|
Map<String, List<DBColumn>> dbColumnMap = new HashMap<>();
|
||||||
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
|
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
|
||||||
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
|
List<DBColumn> columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
|
||||||
|
dbColumnMap.put(modelSchemaReq.getSql(), columns);
|
||||||
} else {
|
} else {
|
||||||
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(),
|
for (String table : modelSchemaReq.getTables()) {
|
||||||
modelSchemaReq.getTable());
|
List<DBColumn> columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table);
|
||||||
|
dbColumnMap.put(table, columns);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return dbColumnMap;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DBColumn> getColumns(Long id, String db, String table) throws SQLException {
|
public List<DBColumn> getColumns(Long id, String db, String table) throws SQLException {
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
package com.tencent.supersonic.headless.server.service.impl;
|
package com.tencent.supersonic.headless.server.service.impl;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
|
import com.tencent.supersonic.common.config.ChatModel;
|
||||||
import com.tencent.supersonic.common.pojo.ItemDateResp;
|
import com.tencent.supersonic.common.pojo.ItemDateResp;
|
||||||
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
|
import com.tencent.supersonic.common.service.ChatModelService;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DbSchema;
|
import com.tencent.supersonic.headless.api.pojo.DbSchema;
|
||||||
@@ -46,14 +48,6 @@ import com.tencent.supersonic.headless.server.service.MetricService;
|
|||||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
import com.tencent.supersonic.headless.server.utils.ModelConverter;
|
import com.tencent.supersonic.headless.server.utils.ModelConverter;
|
||||||
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.BeanUtils;
|
|
||||||
import org.springframework.context.annotation.Lazy;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.sql.SQLException;
|
import java.sql.SQLException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
@@ -64,6 +58,13 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -87,11 +88,13 @@ public class ModelServiceImpl implements ModelService {
|
|||||||
|
|
||||||
private ModelIntelligentBuilder modelIntelligentBuilder;
|
private ModelIntelligentBuilder modelIntelligentBuilder;
|
||||||
|
|
||||||
|
private ChatModelService chatModelService;
|
||||||
|
|
||||||
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
|
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
|
||||||
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
|
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
|
||||||
DomainService domainService, UserService userService, DataSetService dataSetService,
|
DomainService domainService, UserService userService, DataSetService dataSetService,
|
||||||
DateInfoRepository dateInfoRepository,
|
DateInfoRepository dateInfoRepository,
|
||||||
ModelIntelligentBuilder modelIntelligentBuilder) {
|
ModelIntelligentBuilder modelIntelligentBuilder, ChatModelService chatModelService) {
|
||||||
this.modelRepository = modelRepository;
|
this.modelRepository = modelRepository;
|
||||||
this.databaseService = databaseService;
|
this.databaseService = databaseService;
|
||||||
this.dimensionService = dimensionService;
|
this.dimensionService = dimensionService;
|
||||||
@@ -101,6 +104,7 @@ public class ModelServiceImpl implements ModelService {
|
|||||||
this.dataSetService = dataSetService;
|
this.dataSetService = dataSetService;
|
||||||
this.dateInfoRepository = dateInfoRepository;
|
this.dateInfoRepository = dateInfoRepository;
|
||||||
this.modelIntelligentBuilder = modelIntelligentBuilder;
|
this.modelIntelligentBuilder = modelIntelligentBuilder;
|
||||||
|
this.chatModelService = chatModelService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -196,22 +200,31 @@ public class ModelServiceImpl implements ModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
|
public Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
|
||||||
List<DBColumn> dbColumns = databaseService.getDbColumns(modelSchemaReq);
|
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelSchemaReq);
|
||||||
|
Map<String, ModelSchema> modelSchemaMap = new HashMap<>();
|
||||||
if (modelSchemaReq.isBuildByLLM()) {
|
if (modelSchemaReq.isBuildByLLM()) {
|
||||||
DbSchema dbSchema = convert(modelSchemaReq, dbColumns);
|
ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId());
|
||||||
|
modelSchemaReq.setChatModelConfig(chatModel.getConfig());
|
||||||
|
}
|
||||||
|
for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
|
||||||
|
if (modelSchemaReq.isBuildByLLM()) {
|
||||||
|
DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue());
|
||||||
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
|
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
|
||||||
if (modelSchema != null) {
|
if (modelSchema != null) {
|
||||||
return modelSchema;
|
modelSchemaMap.put(entry.getKey(), modelSchema);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
modelSchemaMap.put(entry.getKey(), build(entry.getValue()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return build(dbColumns);
|
return modelSchemaMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
private DbSchema convert(ModelSchemaReq modelSchemaReq, List<DBColumn> dbColumns) {
|
private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
|
||||||
DbSchema dbSchema = new DbSchema();
|
DbSchema dbSchema = new DbSchema();
|
||||||
dbSchema.setDb(modelSchemaReq.getDb());
|
dbSchema.setDb(modelSchemaReq.getDb());
|
||||||
dbSchema.setTable(modelSchemaReq.getTable());
|
dbSchema.setTable(key);
|
||||||
dbSchema.setSql(modelSchemaReq.getSql());
|
dbSchema.setSql(modelSchemaReq.getSql());
|
||||||
dbSchema.setDbColumns(dbColumns);
|
dbSchema.setDbColumns(dbColumns);
|
||||||
return dbSchema;
|
return dbSchema;
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.User;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||||
|
import com.tencent.supersonic.common.service.ChatModelService;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
import com.tencent.supersonic.headless.api.pojo.Identify;
|
||||||
@@ -78,9 +79,10 @@ class ModelServiceImplTest {
|
|||||||
DataSetService viewService = Mockito.mock(DataSetService.class);
|
DataSetService viewService = Mockito.mock(DataSetService.class);
|
||||||
ModelIntelligentBuilder modelIntelligentBuilder =
|
ModelIntelligentBuilder modelIntelligentBuilder =
|
||||||
Mockito.mock(ModelIntelligentBuilder.class);
|
Mockito.mock(ModelIntelligentBuilder.class);
|
||||||
|
ChatModelService chatModelService = Mockito.mock(ChatModelService.class);
|
||||||
return new ModelServiceImpl(modelRepository, databaseService, dimensionService,
|
return new ModelServiceImpl(modelRepository, databaseService, dimensionService,
|
||||||
metricService, domainService, userService, viewService, dateInfoRepository,
|
metricService, domainService, userService, viewService, dateInfoRepository,
|
||||||
modelIntelligentBuilder);
|
modelIntelligentBuilder, chatModelService);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ModelReq mockModelReq() {
|
private ModelReq mockModelReq() {
|
||||||
|
|||||||
Reference in New Issue
Block a user