From bd82b0904b61279711b140a7bffd87a9d774bc93 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Sun, 27 Oct 2024 18:16:59 +0800 Subject: [PATCH] (improvement)(Headless) Add integration testing for building data-model by LLM (#1848) Co-authored-by: lxwcodemonkey --- .../headless/api/pojo/FieldSchema.java | 4 + .../headless/api/pojo/ModelSchema.java | 11 +++ ...ModelSchemaReq.java => ModelBuildReq.java} | 2 +- .../builder/ModelIntelligentBuilder.java | 33 ++++--- .../server/rest/DatabaseController.java | 6 +- .../headless/server/rest/ModelController.java | 6 +- .../server/service/DatabaseService.java | 4 +- .../headless/server/service/ModelService.java | 4 +- .../service/impl/DatabaseServiceImpl.java | 14 +-- .../server/service/impl/ModelServiceImpl.java | 85 +++++++++++++----- .../headless/ModelIntelligentBuildTest.java | 90 +++++++++++++++++++ .../src/test/resources/application-local.yaml | 2 +- 12 files changed, 211 insertions(+), 50 deletions(-) rename headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/{ModelSchemaReq.java => ModelBuildReq.java} (92%) create mode 100644 launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/FieldSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/FieldSchema.java index 12f4aff82..2fcb41b6e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/FieldSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/FieldSchema.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.api.pojo; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import lombok.Data; @@ -14,5 +15,8 @@ public class FieldSchema { private FieldType filedType; + private AggOperatorEnum agg; + private String name; + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelSchema.java index 1eafb8a44..781d10bd7 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelSchema.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.api.pojo; +import com.fasterxml.jackson.annotation.JsonIgnore; import lombok.Data; import java.util.List; @@ -15,4 +16,14 @@ public class ModelSchema { private List filedSchemas; + @JsonIgnore + public FieldSchema getFieldByName(String columnName) { + for (FieldSchema fieldSchema : filedSchemas) { + if (fieldSchema.getColumnName().equalsIgnoreCase(columnName)) { + return fieldSchema; + } + } + return null; + } + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java similarity index 92% rename from headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java rename to headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java index ab8547ace..38299696a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelSchemaReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java @@ -6,7 +6,7 @@ import lombok.Data; import java.util.List; @Data -public class ModelSchemaReq { +public class ModelBuildReq { private Long databaseId; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java index 9c6da8fcf..d2c37dcb4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java @@ -7,16 +7,19 @@ import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.DbSchema; import com.tencent.supersonic.headless.api.pojo.ModelSchema; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.service.AiServices; +import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; +@Slf4j @Component public class ModelIntelligentBuilder extends IntelligentBuilder { @@ -33,10 +36,13 @@ public class ModelIntelligentBuilder extends IntelligentBuilder { + "\n primary_key: This is a unique identifier for a record row in a database." + "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table." + "\n data_time: This represents the time when data is generated in the data warehouse." - + "\n dimension: Usually a string type, used for grouping and filtering data." - + "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective." - - + "\nDBSchema: {{DBSchema}}" + "\nExemplar: {{exemplar}}"; + + "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions" + + "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. " + + " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. " + + "\nTip: I will also give you other related dbSchemas. If you determine that different dbSchemas have the same fields, " + + " they can be primary and foreign key relationships." + + "\nDBSchema: {{DBSchema}}" + "\nOtherRelatedDBSchema: {{otherRelatedDBSchema}}" + + "\nExemplar: {{exemplar}}"; public ModelIntelligentBuilder() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型") @@ -49,22 +55,29 @@ public class ModelIntelligentBuilder extends IntelligentBuilder { } - public ModelSchema build(DbSchema dbSchema, ModelSchemaReq modelSchemaReq) { + public ModelSchema build(DbSchema dbSchema, List otherDbSchema, + ModelBuildReq modelBuildReq) { Optional chatApp = ChatAppManager.getApp(APP_KEY); if (!chatApp.isPresent() || !chatApp.get().isEnable()) { return null; } - ChatModelConfig chatModelConfig = modelSchemaReq.getChatModelConfig(); + ChatModelConfig chatModelConfig = modelBuildReq.getChatModelConfig(); ModelSchemaExtractor extractor = AiServices.create(ModelSchemaExtractor.class, getChatModel(chatModelConfig)); - Prompt prompt = generatePrompt(dbSchema, chatApp.get()); - return extractor.generateModelSchema(prompt.toUserMessage().singleText()); + Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get()); + ModelSchema modelSchema = + extractor.generateModelSchema(prompt.toUserMessage().singleText()); + log.info("dbSchema: {} modelSchema: {}", JsonUtil.toString(dbSchema), + JsonUtil.toString(modelSchema)); + return modelSchema; } - private Prompt generatePrompt(DbSchema dbSchema, ChatApp chatApp) { + private Prompt generatePrompt(DbSchema dbSchema, List otherDbSchema, + ChatApp chatApp) { Map variable = new HashMap<>(); variable.put("exemplar", loadExemplars()); variable.put("DBSchema", JsonUtil.toString(dbSchema)); + variable.put("otherRelatedDBSchema", JsonUtil.toString(otherDbSchema)); return PromptTemplate.from(chatApp.getPrompt()).apply(variable); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java index bd7e42f16..01d7fd188 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java @@ -7,7 +7,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; @@ -96,9 +96,9 @@ public class DatabaseController { } @PostMapping("/listColumnsBySql") - public List listColumnsBySql(@RequestBody ModelSchemaReq modelSchemaReq) + public List listColumnsBySql(@RequestBody ModelBuildReq modelBuildReq) throws SQLException { - return databaseService.getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql()); + return databaseService.getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql()); } @GetMapping("/getDatabaseParameters") diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java index fb2fd67b4..0605101c0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java @@ -10,8 +10,8 @@ import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.headless.api.pojo.ModelSchema; import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp; @@ -111,8 +111,8 @@ public class ModelController { } @PostMapping("/buildModelSchema") - public Map buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq) + public Map buildModelSchema(@RequestBody ModelBuildReq modelBuildReq) throws SQLException { - return modelService.buildModelSchema(modelSchemaReq); + return modelService.buildModelSchema(modelBuildReq); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java index 661111a6e..8d6811a6c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DatabaseService.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.service; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; @@ -37,7 +37,7 @@ public interface DatabaseService { List getTables(Long id, String db) throws SQLException; - Map> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException; + Map> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException; List getColumns(Long id, String db, String table) throws SQLException; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java index 3b55aa6cf..203884b4d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java @@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.MetaFilter; import com.tencent.supersonic.headless.api.pojo.ModelSchema; import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp; @@ -35,7 +35,7 @@ public interface ModelService { UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq); - Map buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException; + Map buildModelSchema(ModelBuildReq modelBuildReq) throws SQLException; List getModelListWithAuth(User user, Long domainId, AuthType authType); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index a4cc51f56..975847918 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -7,7 +7,7 @@ import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; @@ -208,17 +208,17 @@ public class DatabaseServiceImpl extends ServiceImpl> getDbColumns(ModelSchemaReq modelSchemaReq) + public Map> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException { Map> dbColumnMap = new HashMap<>(); - if (StringUtils.isNotBlank(modelSchemaReq.getSql())) { + if (StringUtils.isNotBlank(modelBuildReq.getSql())) { List columns = - getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql()); - dbColumnMap.put(modelSchemaReq.getSql(), columns); + getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql()); + dbColumnMap.put(modelBuildReq.getSql(), columns); } else { - for (String table : modelSchemaReq.getTables()) { + for (String table : modelBuildReq.getTables()) { List columns = - getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table); + getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getDb(), table); dbColumnMap.put(table, columns); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index ff8f01367..b2dc364e7 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -26,8 +26,8 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.MetricReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; @@ -41,7 +41,13 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO; import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository; import com.tencent.supersonic.headless.server.persistence.repository.ModelRepository; import com.tencent.supersonic.headless.server.pojo.ModelFilter; -import com.tencent.supersonic.headless.server.service.*; +import com.tencent.supersonic.headless.server.service.DataSetService; +import com.tencent.supersonic.headless.server.service.DatabaseService; +import com.tencent.supersonic.headless.server.service.DimensionService; +import com.tencent.supersonic.headless.server.service.DomainService; +import com.tencent.supersonic.headless.server.service.MetricService; +import com.tencent.supersonic.headless.server.service.ModelRelaService; +import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.utils.ModelConverter; import com.tencent.supersonic.headless.server.utils.NameCheckUtils; import lombok.extern.slf4j.Slf4j; @@ -53,7 +59,21 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.util.CollectionUtils; import java.sql.SQLException; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @Service @@ -82,6 +102,9 @@ public class ModelServiceImpl implements ModelService { private ModelRelaService modelRelaService; + ExecutorService executor = + new ThreadPoolExecutor(0, 5, 100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>()); + public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService, @Lazy DimensionService dimensionService, @Lazy MetricService metricService, DomainService domainService, UserService userService, DataSetService dataSetService, @@ -194,29 +217,49 @@ public class ModelServiceImpl implements ModelService { } @Override - public Map buildModelSchema(ModelSchemaReq modelSchemaReq) + public Map buildModelSchema(ModelBuildReq modelBuildReq) throws SQLException { - Map> dbColumnMap = databaseService.getDbColumns(modelSchemaReq); - Map modelSchemaMap = new HashMap<>(); - if (modelSchemaReq.isBuildByLLM()) { - ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId()); - modelSchemaReq.setChatModelConfig(chatModel.getConfig()); - } - for (Map.Entry> entry : dbColumnMap.entrySet()) { - if (modelSchemaReq.isBuildByLLM()) { - DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue()); - ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq); - if (modelSchema != null) { - modelSchemaMap.put(entry.getKey(), modelSchema); - } - } else { - modelSchemaMap.put(entry.getKey(), build(entry.getValue())); - } + Map> dbColumnMap = databaseService.getDbColumns(modelBuildReq); + if (modelBuildReq.isBuildByLLM() && modelBuildReq.getChatModelConfig() == null) { + ChatModel chatModel = chatModelService.getChatModel(modelBuildReq.getChatModelId()); + modelBuildReq.setChatModelConfig(chatModel.getConfig()); } + List dbSchemas = convert(dbColumnMap, modelBuildReq); + Map modelSchemaMap = new ConcurrentHashMap<>(); + CompletableFuture.allOf(dbSchemas.stream() + .map(dbSchema -> CompletableFuture.runAsync( + () -> doBuild(modelBuildReq, dbSchema, dbSchemas, modelSchemaMap), + executor)) + .toArray(CompletableFuture[]::new)).join(); return modelSchemaMap; } - private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List dbColumns) { + private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List dbSchemas, + Map modelSchemaMap) { + if (modelBuildReq.isBuildByLLM()) { + List otherDbSchema = getOtherDbSchema(curSchema, dbSchemas); + ModelSchema modelSchema = + modelIntelligentBuilder.build(curSchema, otherDbSchema, modelBuildReq); + modelSchemaMap.put(curSchema.getTable(), modelSchema); + } else { + modelSchemaMap.put(curSchema.getTable(), build(curSchema.getDbColumns())); + } + } + + private List getOtherDbSchema(DbSchema curSchema, List dbSchemas) { + return dbSchemas.stream() + .filter(dbSchema -> !dbSchema.getTable().equals(curSchema.getTable())) + .collect(Collectors.toList()); + } + + private List convert(Map> dbColumnMap, + ModelBuildReq modelSchemaReq) { + return dbColumnMap.keySet().stream() + .map(key -> convert(modelSchemaReq, key, dbColumnMap.get(key))) + .collect(Collectors.toList()); + } + + private DbSchema convert(ModelBuildReq modelSchemaReq, String key, List dbColumns) { DbSchema dbSchema = new DbSchema(); dbSchema.setDb(modelSchemaReq.getDb()); dbSchema.setTable(key); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java new file mode 100644 index 000000000..b383c1ad4 --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java @@ -0,0 +1,90 @@ +package com.tencent.supersonic.headless; + + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.ModelSchema; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; +import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; +import com.tencent.supersonic.headless.server.service.ModelService; +import com.tencent.supersonic.util.LLMConfigUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; + +import java.sql.SQLException; +import java.util.Map; + +@Disabled +public class ModelIntelligentBuildTest extends BaseTest { + + private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3; + + @Autowired + private ModelService modelService; + + @Test + public void testBuildModelBatch() throws SQLException { + ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType); + ModelBuildReq modelSchemaReq = new ModelBuildReq(); + modelSchemaReq.setChatModelConfig(llmConfig); + modelSchemaReq.setBuildByLLM(true); + modelSchemaReq.setDatabaseId(1L); + modelSchemaReq.setDb("semantic"); + modelSchemaReq.setTables(Lists.newArrayList("s2_user_department", "s2_stay_time_statis")); + Map modelSchemaMap = modelService.buildModelSchema(modelSchemaReq); + + ModelSchema userModelSchema = modelSchemaMap.get("s2_user_department"); + Assertions.assertEquals(2, userModelSchema.getFiledSchemas().size()); + Assertions.assertEquals(FieldType.primary_key, + userModelSchema.getFieldByName("user_name").getFiledType()); + Assertions.assertEquals(FieldType.dimension, + userModelSchema.getFieldByName("department").getFiledType()); + + ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis"); + Assertions.assertEquals(4, stayTimeModelSchema.getFiledSchemas().size()); + Assertions.assertEquals(FieldType.foreign_key, + stayTimeModelSchema.getFieldByName("user_name").getFiledType()); + Assertions.assertEquals(FieldType.data_time, + stayTimeModelSchema.getFieldByName("imp_date").getFiledType()); + Assertions.assertEquals(FieldType.dimension, + stayTimeModelSchema.getFieldByName("page").getFiledType()); + Assertions.assertEquals(FieldType.measure, + stayTimeModelSchema.getFieldByName("stay_hours").getFiledType()); + Assertions.assertEquals(AggOperatorEnum.SUM, + stayTimeModelSchema.getFieldByName("stay_hours").getAgg()); + } + + + @Test + public void testBuildModelBySql() throws SQLException { + ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType); + ModelBuildReq modelSchemaReq = new ModelBuildReq(); + modelSchemaReq.setChatModelConfig(llmConfig); + modelSchemaReq.setBuildByLLM(true); + modelSchemaReq.setDatabaseId(1L); + modelSchemaReq.setDb("semantic"); + modelSchemaReq.setSql( + "SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis"); + Map modelSchemaMap = modelService.buildModelSchema(modelSchemaReq); + + ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next(); + Assertions.assertEquals(5, pvModelSchema.getFiledSchemas().size()); + Assertions.assertEquals(FieldType.data_time, + pvModelSchema.getFieldByName("imp_date").getFiledType()); + Assertions.assertEquals(FieldType.dimension, + pvModelSchema.getFieldByName("user_name").getFiledType()); + Assertions.assertEquals(FieldType.dimension, + pvModelSchema.getFieldByName("page").getFiledType()); + Assertions.assertEquals(FieldType.measure, + pvModelSchema.getFieldByName("pv").getFiledType()); + Assertions.assertEquals(AggOperatorEnum.SUM, pvModelSchema.getFieldByName("pv").getAgg()); + Assertions.assertEquals(FieldType.measure, + pvModelSchema.getFieldByName("uv").getFiledType()); + Assertions.assertEquals(AggOperatorEnum.COUNT_DISTINCT, + pvModelSchema.getFieldByName("uv").getAgg()); + } + +} diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 650ca3a44..5f3e99e22 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -1,7 +1,7 @@ spring: datasource: driver-class-name: org.h2.Driver - url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=30 + url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=100 username: root password: semantic sql: