(improvement)(Headless) Supports batch creation of models by specifying db table names (#1833)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-10-20 20:17:11 +08:00
committed by GitHub
parent 1f208ffcb7
commit 1d84e00887
7 changed files with 53 additions and 28 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import java.util.List;
import lombok.Data;
@Data
@@ -12,9 +13,11 @@ public class ModelSchemaReq {
private String db;
private String table;
private List<String> tables;
private boolean buildByLLM;
private Integer chatModelId;
private ChatModelConfig chatModelConfig;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.server.rest;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -110,7 +111,7 @@ public class ModelController {
}
@PostMapping("/buildModelSchema")
public ModelSchema buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
throws SQLException {
return modelService.buildModelSchema(modelSchemaReq);
}

View File

@@ -37,7 +37,7 @@ public interface DatabaseService {
List<String> getTables(Long id, String db) throws SQLException;
List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;

View File

@@ -35,7 +35,7 @@ public interface ModelService {
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);

View File

@@ -28,6 +28,7 @@ import com.tencent.supersonic.headless.server.pojo.ModelFilter;
import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
import java.util.HashMap;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -207,13 +208,18 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
}
@Override
public List<DBColumn> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
public Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException {
Map<String, List<DBColumn>> dbColumnMap = new HashMap<>();
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
List<DBColumn> columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
dbColumnMap.put(modelSchemaReq.getSql(), columns);
} else {
return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(),
modelSchemaReq.getTable());
for (String table : modelSchemaReq.getTables()) {
List<DBColumn> columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table);
dbColumnMap.put(table, columns);
}
}
return dbColumnMap;
}
@Override

View File

@@ -1,13 +1,15 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.common.config.ChatModel;
import com.tencent.supersonic.common.pojo.ItemDateResp;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.DbSchema;
@@ -46,14 +48,6 @@ import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.ModelConverter;
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
@@ -64,6 +58,13 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
@Service
@Slf4j
@@ -87,11 +88,13 @@ public class ModelServiceImpl implements ModelService {
private ModelIntelligentBuilder modelIntelligentBuilder;
private ChatModelService chatModelService;
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
DomainService domainService, UserService userService, DataSetService dataSetService,
DateInfoRepository dateInfoRepository,
ModelIntelligentBuilder modelIntelligentBuilder) {
ModelIntelligentBuilder modelIntelligentBuilder, ChatModelService chatModelService) {
this.modelRepository = modelRepository;
this.databaseService = databaseService;
this.dimensionService = dimensionService;
@@ -101,6 +104,7 @@ public class ModelServiceImpl implements ModelService {
this.dataSetService = dataSetService;
this.dateInfoRepository = dateInfoRepository;
this.modelIntelligentBuilder = modelIntelligentBuilder;
this.chatModelService = chatModelService;
}
@Override
@@ -196,22 +200,31 @@ public class ModelServiceImpl implements ModelService {
}
@Override
public ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
List<DBColumn> dbColumns = databaseService.getDbColumns(modelSchemaReq);
public Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException {
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelSchemaReq);
Map<String, ModelSchema> modelSchemaMap = new HashMap<>();
if (modelSchemaReq.isBuildByLLM()) {
DbSchema dbSchema = convert(modelSchemaReq, dbColumns);
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
if (modelSchema != null) {
return modelSchema;
ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId());
modelSchemaReq.setChatModelConfig(chatModel.getConfig());
}
for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
if (modelSchemaReq.isBuildByLLM()) {
DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue());
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
if (modelSchema != null) {
modelSchemaMap.put(entry.getKey(), modelSchema);
}
} else {
modelSchemaMap.put(entry.getKey(), build(entry.getValue()));
}
}
return build(dbColumns);
return modelSchemaMap;
}
private DbSchema convert(ModelSchemaReq modelSchemaReq, List<DBColumn> dbColumns) {
private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
DbSchema dbSchema = new DbSchema();
dbSchema.setDb(modelSchemaReq.getDb());
dbSchema.setTable(modelSchemaReq.getTable());
dbSchema.setTable(key);
dbSchema.setSql(modelSchemaReq.getSql());
dbSchema.setDbColumns(dbColumns);
return dbSchema;

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
@@ -78,9 +79,10 @@ class ModelServiceImplTest {
DataSetService viewService = Mockito.mock(DataSetService.class);
ModelIntelligentBuilder modelIntelligentBuilder =
Mockito.mock(ModelIntelligentBuilder.class);
ChatModelService chatModelService = Mockito.mock(ChatModelService.class);
return new ModelServiceImpl(modelRepository, databaseService, dimensionService,
metricService, domainService, userService, viewService, dateInfoRepository,
modelIntelligentBuilder);
modelIntelligentBuilder, chatModelService);
}
private ModelReq mockModelReq() {