mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(Headless) Supports batch creation of models by specifying db table names (#1833)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@@ -12,9 +13,11 @@ public class ModelSchemaReq {
|
||||
|
||||
private String db;
|
||||
|
||||
private String table;
|
||||
private List<String> tables;
|
||||
|
||||
private boolean buildByLLM;
|
||||
|
||||
private Integer chatModelId;
|
||||
|
||||
private ChatModelConfig chatModelConfig;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.rest;
|
||||
|
||||
import java.util.Map;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
@@ -110,7 +111,7 @@ public class ModelController {
|
||||
}
|
||||
|
||||
@PostMapping("/buildModelSchema")
|
||||
public ModelSchema buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
|
||||
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
|
||||
throws SQLException {
|
||||
return modelService.buildModelSchema(modelSchemaReq);
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ public interface DatabaseService {
|
||||
|
||||
List<String> getTables(Long id, String db) throws SQLException;
|
||||
|
||||
List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||
Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||
|
||||
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ public interface ModelService {
|
||||
|
||||
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
|
||||
|
||||
ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||
Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||
|
||||
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import com.tencent.supersonic.headless.server.pojo.ModelFilter;
|
||||
import com.tencent.supersonic.headless.server.service.DatabaseService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
|
||||
import java.util.HashMap;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -207,13 +208,18 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
|
||||
public Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
|
||||
Map<String, List<DBColumn>> dbColumnMap = new HashMap<>();
|
||||
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
|
||||
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
|
||||
List<DBColumn> columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
|
||||
dbColumnMap.put(modelSchemaReq.getSql(), columns);
|
||||
} else {
|
||||
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(),
|
||||
modelSchemaReq.getTable());
|
||||
for (String table : modelSchemaReq.getTables()) {
|
||||
List<DBColumn> columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table);
|
||||
dbColumnMap.put(table, columns);
|
||||
}
|
||||
}
|
||||
return dbColumnMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ItemDateResp;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.DbSchema;
|
||||
@@ -46,14 +48,6 @@ import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.utils.ModelConverter;
|
||||
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
@@ -64,6 +58,13 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -87,11 +88,13 @@ public class ModelServiceImpl implements ModelService {
|
||||
|
||||
private ModelIntelligentBuilder modelIntelligentBuilder;
|
||||
|
||||
private ChatModelService chatModelService;
|
||||
|
||||
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
|
||||
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
|
||||
DomainService domainService, UserService userService, DataSetService dataSetService,
|
||||
DateInfoRepository dateInfoRepository,
|
||||
ModelIntelligentBuilder modelIntelligentBuilder) {
|
||||
ModelIntelligentBuilder modelIntelligentBuilder, ChatModelService chatModelService) {
|
||||
this.modelRepository = modelRepository;
|
||||
this.databaseService = databaseService;
|
||||
this.dimensionService = dimensionService;
|
||||
@@ -101,6 +104,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
this.dataSetService = dataSetService;
|
||||
this.dateInfoRepository = dateInfoRepository;
|
||||
this.modelIntelligentBuilder = modelIntelligentBuilder;
|
||||
this.chatModelService = chatModelService;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -196,22 +200,31 @@ public class ModelServiceImpl implements ModelService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
|
||||
List<DBColumn> dbColumns = databaseService.getDbColumns(modelSchemaReq);
|
||||
public Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
|
||||
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelSchemaReq);
|
||||
Map<String, ModelSchema> modelSchemaMap = new HashMap<>();
|
||||
if (modelSchemaReq.isBuildByLLM()) {
|
||||
DbSchema dbSchema = convert(modelSchemaReq, dbColumns);
|
||||
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
|
||||
if (modelSchema != null) {
|
||||
return modelSchema;
|
||||
ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId());
|
||||
modelSchemaReq.setChatModelConfig(chatModel.getConfig());
|
||||
}
|
||||
for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
|
||||
if (modelSchemaReq.isBuildByLLM()) {
|
||||
DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue());
|
||||
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
|
||||
if (modelSchema != null) {
|
||||
modelSchemaMap.put(entry.getKey(), modelSchema);
|
||||
}
|
||||
} else {
|
||||
modelSchemaMap.put(entry.getKey(), build(entry.getValue()));
|
||||
}
|
||||
}
|
||||
return build(dbColumns);
|
||||
return modelSchemaMap;
|
||||
}
|
||||
|
||||
private DbSchema convert(ModelSchemaReq modelSchemaReq, List<DBColumn> dbColumns) {
|
||||
private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
|
||||
DbSchema dbSchema = new DbSchema();
|
||||
dbSchema.setDb(modelSchemaReq.getDb());
|
||||
dbSchema.setTable(modelSchemaReq.getTable());
|
||||
dbSchema.setTable(key);
|
||||
dbSchema.setSql(modelSchemaReq.getSql());
|
||||
dbSchema.setDbColumns(dbColumns);
|
||||
return dbSchema;
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
||||
@@ -78,9 +79,10 @@ class ModelServiceImplTest {
|
||||
DataSetService viewService = Mockito.mock(DataSetService.class);
|
||||
ModelIntelligentBuilder modelIntelligentBuilder =
|
||||
Mockito.mock(ModelIntelligentBuilder.class);
|
||||
ChatModelService chatModelService = Mockito.mock(ChatModelService.class);
|
||||
return new ModelServiceImpl(modelRepository, databaseService, dimensionService,
|
||||
metricService, domainService, userService, viewService, dateInfoRepository,
|
||||
modelIntelligentBuilder);
|
||||
modelIntelligentBuilder, chatModelService);
|
||||
}
|
||||
|
||||
private ModelReq mockModelReq() {
|
||||
|
||||
Reference in New Issue
Block a user