diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java index 85f31aa0d..87486f837 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.tencent.supersonic.common.pojo.ChatModelConfig; +import java.util.List; import lombok.Data; @Data @@ -12,9 +13,11 @@ public class ModelSchemaReq { private String db; - private String table; + private List tables; private boolean buildByLLM; + private Integer chatModelId; + private ChatModelConfig chatModelConfig; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java index 497aa97eb..7162e6dba 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.server.rest; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -110,7 +111,7 @@ public class ModelController { } @PostMapping("/buildModelSchema") - public ModelSchema buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq) + public Map buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq) throws SQLException { return modelService.buildModelSchema(modelSchemaReq); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java index b6e2ecbd9..661111a6e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java @@ -37,7 +37,7 @@ public interface DatabaseService { List getTables(Long id, String db) throws SQLException; - List getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException; + Map> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException; List getColumns(Long id, String db, String table) throws SQLException; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java index 6c385f753..de72e31c0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java @@ -35,7 +35,7 @@ public interface ModelService { UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq); - ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException; + Map buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException; List getModelListWithAuth(User user, Long domainId, AuthType authType); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 8b9f165e6..a8a0df228 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -28,6 +28,7 @@ import com.tencent.supersonic.headless.server.pojo.ModelFilter; import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.utils.DatabaseConverter; +import java.util.HashMap; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -207,13 +208,18 @@ public class DatabaseServiceImpl extends ServiceImpl getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException { + public Map> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException { + Map> dbColumnMap = new HashMap<>(); if (StringUtils.isNotBlank(modelSchemaReq.getSql())) { - return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql()); + List columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql()); + dbColumnMap.put(modelSchemaReq.getSql(), columns); } else { - return getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), - modelSchemaReq.getTable()); + for (String table : modelSchemaReq.getTables()) { + List columns = getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table); + dbColumnMap.put(table, columns); + } } + return dbColumnMap; } @Override diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index c7fb7402c..02d549ae0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -1,13 +1,15 @@ package com.tencent.supersonic.headless.server.service.impl; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.auth.api.authentication.service.UserService; +import com.tencent.supersonic.common.config.ChatModel; import com.tencent.supersonic.common.pojo.ItemDateResp; +import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; +import com.tencent.supersonic.common.service.ChatModelService; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DbSchema; @@ -46,14 +48,6 @@ import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.utils.ModelConverter; import com.tencent.supersonic.headless.server.utils.NameCheckUtils; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.context.annotation.Lazy; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; -import org.springframework.util.CollectionUtils; - import java.sql.SQLException; import java.util.ArrayList; import java.util.Comparator; @@ -64,6 +58,13 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.util.CollectionUtils; @Service @Slf4j @@ -87,11 +88,13 @@ public class ModelServiceImpl implements ModelService { private ModelIntelligentBuilder modelIntelligentBuilder; + private ChatModelService chatModelService; + public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService, @Lazy DimensionService dimensionService, @Lazy MetricService metricService, DomainService domainService, UserService userService, DataSetService dataSetService, DateInfoRepository dateInfoRepository, - ModelIntelligentBuilder modelIntelligentBuilder) { + ModelIntelligentBuilder modelIntelligentBuilder, ChatModelService chatModelService) { this.modelRepository = modelRepository; this.databaseService = databaseService; this.dimensionService = dimensionService; @@ -101,6 +104,7 @@ public class ModelServiceImpl implements ModelService { this.dataSetService = dataSetService; this.dateInfoRepository = dateInfoRepository; this.modelIntelligentBuilder = modelIntelligentBuilder; + this.chatModelService = chatModelService; } @Override @@ -196,22 +200,31 @@ public class ModelServiceImpl implements ModelService { } @Override - public ModelSchema buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException { - List dbColumns = databaseService.getDbColumns(modelSchemaReq); + public Map buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException { + Map> dbColumnMap = databaseService.getDbColumns(modelSchemaReq); + Map modelSchemaMap = new HashMap<>(); if (modelSchemaReq.isBuildByLLM()) { - DbSchema dbSchema = convert(modelSchemaReq, dbColumns); - ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq); - if (modelSchema != null) { - return modelSchema; + ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId()); + modelSchemaReq.setChatModelConfig(chatModel.getConfig()); + } + for (Map.Entry> entry : dbColumnMap.entrySet()) { + if (modelSchemaReq.isBuildByLLM()) { + DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue()); + ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq); + if (modelSchema != null) { + modelSchemaMap.put(entry.getKey(), modelSchema); + } + } else { + modelSchemaMap.put(entry.getKey(), build(entry.getValue())); } } - return build(dbColumns); + return modelSchemaMap; } - private DbSchema convert(ModelSchemaReq modelSchemaReq, List dbColumns) { + private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List dbColumns) { DbSchema dbSchema = new DbSchema(); dbSchema.setDb(modelSchemaReq.getDb()); - dbSchema.setTable(modelSchemaReq.getTable()); + dbSchema.setTable(key); dbSchema.setSql(modelSchemaReq.getSql()); dbSchema.setDbColumns(dbColumns); return dbSchema; diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java index 11eb1835e..40dc01ada 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.auth.api.authentication.service.UserService; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum; +import com.tencent.supersonic.common.service.ChatModelService; import com.tencent.supersonic.headless.api.pojo.Dim; import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams; import com.tencent.supersonic.headless.api.pojo.Identify; @@ -78,9 +79,10 @@ class ModelServiceImplTest { DataSetService viewService = Mockito.mock(DataSetService.class); ModelIntelligentBuilder modelIntelligentBuilder = Mockito.mock(ModelIntelligentBuilder.class); + ChatModelService chatModelService = Mockito.mock(ChatModelService.class); return new ModelServiceImpl(modelRepository, databaseService, dimensionService, metricService, domainService, userService, viewService, dateInfoRepository, - modelIntelligentBuilder); + modelIntelligentBuilder, chatModelService); } private ModelReq mockModelReq() {